第一次提交
This commit is contained in:
194
service/auto_tune.go
Normal file
194
service/auto_tune.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"model-asynch/consts/public"
|
||||
"model-asynch/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
// AutoTuneResult 单次调参结果(按 model_name)
|
||||
type AutoTuneResult struct {
|
||||
ModelName string `json:"modelName"` // 模型名称(asynch_models.model_name)
|
||||
Samples int `json:"samples"` // 统计样本数(窗口内 state=2/3 且 started_at/finished_at 非空的任务数量)
|
||||
P90Exec float64 `json:"p90ExecSeconds"` // 执行耗时 P90(秒),口径:finished_at - started_at
|
||||
|
||||
CapMaxConcurrency int `json:"capMaxConcurrency"` // 配置上限:asynch_models.max_concurrency(cap,不会被动态调参覆盖)
|
||||
OldMaxConcurrency int `json:"oldMaxConcurrency"` // 调参前运行时值(Redis),若无则等于 cap
|
||||
NewMaxConcurrency int `json:"newMaxConcurrency"` // 本次计算出的运行时值(将写入 Redis),受 ±50% 约束且不超过 cap
|
||||
|
||||
CapQueueLimit int `json:"capQueueLimit"` // 配置上限:asynch_models.queue_limit(cap,不会被动态调参覆盖)
|
||||
OldQueueLimit int `json:"oldQueueLimit"` // 调参前运行时值(Redis),若无则等于 cap
|
||||
NewQueueLimit int `json:"newQueueLimit"` // 本次计算出的运行时值(将写入 Redis),受 ±50% 约束且不超过 cap
|
||||
|
||||
ExpectedSeconds int `json:"expectedSeconds"` // 模型预计执行时间(秒):asynch_models.expected_seconds(用于 queue_limit 计算绑定)
|
||||
}
|
||||
|
||||
// AutoTune 由上层定时任务通过接口触发:
|
||||
// - 统计指定时间窗口内该模型任务的执行耗时(finished_at - started_at,取 P90)
|
||||
// - 基于吞吐与 P90 执行耗时估算 max_concurrency 的运行时值(不超过 cap)
|
||||
// - queue_limit 与 expected_seconds 绑定(允许排队时间 = expected_seconds * 2),生成运行时值(不超过 cap)
|
||||
// - 单次调整幅度限制 ±50%,写入 Redis(带 TTL)
|
||||
func AutoTune(ctx context.Context, windowSeconds int) ([]AutoTuneResult, error) {
|
||||
if windowSeconds <= 0 {
|
||||
windowSeconds = 3600
|
||||
}
|
||||
// 1) 读取模型配置(cap),按 model_name 聚合去重(如果表里有多租户重复数据,取较大上限)
|
||||
var modelRows []*entity.AsynchModel
|
||||
if err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
||||
Where("deleted_at IS NULL").
|
||||
Where(entity.AsynchModelCol.Enabled, 1).
|
||||
Scan(&modelRows); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
modelMap := make(map[string]*entity.AsynchModel)
|
||||
for _, m := range modelRows {
|
||||
if m == nil || m.ModelName == "" {
|
||||
continue
|
||||
}
|
||||
cur := modelMap[m.ModelName]
|
||||
if cur == nil {
|
||||
modelMap[m.ModelName] = m
|
||||
continue
|
||||
}
|
||||
// 取更大的 cap
|
||||
if m.MaxConcurrency > cur.MaxConcurrency {
|
||||
cur.MaxConcurrency = m.MaxConcurrency
|
||||
}
|
||||
if m.QueueLimit > cur.QueueLimit {
|
||||
cur.QueueLimit = m.QueueLimit
|
||||
}
|
||||
if m.ExpectedSeconds > cur.ExpectedSeconds {
|
||||
cur.ExpectedSeconds = m.ExpectedSeconds
|
||||
}
|
||||
}
|
||||
if len(modelMap) == 0 {
|
||||
return []AutoTuneResult{}, nil
|
||||
}
|
||||
|
||||
// 2) 统计指定窗口:按 model_name 计算 cnt 和 P90 执行耗时
|
||||
type statRow struct {
|
||||
ModelName string
|
||||
Cnt int
|
||||
P90Exec float64
|
||||
}
|
||||
var stats []statRow
|
||||
sql := fmt.Sprintf(`
|
||||
SELECT model_name,
|
||||
COUNT(1) AS cnt,
|
||||
COALESCE(percentile_cont(0.9) WITHIN GROUP (ORDER BY EXTRACT(EPOCH FROM (finished_at - started_at))), 0) AS p90_exec
|
||||
FROM %s
|
||||
WHERE deleted_at IS NULL
|
||||
AND state IN (2,3)
|
||||
AND started_at IS NOT NULL
|
||||
AND finished_at IS NOT NULL
|
||||
AND finished_at >= (NOW() - (? || ' seconds')::interval)
|
||||
GROUP BY model_name`, public.TableNameTask)
|
||||
r, err := gfdb.DB(ctx).GetAll(ctx, sql, windowSeconds)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_ = r.Structs(&stats)
|
||||
statMap := make(map[string]statRow, len(stats))
|
||||
for _, s := range stats {
|
||||
statMap[s.ModelName] = s
|
||||
}
|
||||
|
||||
// 3) 调参计算
|
||||
const utilization = 0.8
|
||||
const maxChangeRatio = 0.5 // ±50%
|
||||
const queueFactor = 2.0 // 与 expected_seconds 绑定:W_target = expected_seconds * 2
|
||||
|
||||
out := make([]AutoTuneResult, 0, len(modelMap))
|
||||
for modelName, m := range modelMap {
|
||||
s := statMap[modelName]
|
||||
capMax := m.MaxConcurrency
|
||||
capQueue := m.QueueLimit
|
||||
oldMax := GetRuntimeMaxConcurrency(ctx, modelName, capMax)
|
||||
oldQueue := GetRuntimeQueueLimit(ctx, modelName, capQueue)
|
||||
|
||||
// 默认:无样本则不调整
|
||||
if s.Cnt <= 0 || s.P90Exec <= 0 {
|
||||
out = append(out, AutoTuneResult{
|
||||
ModelName: modelName,
|
||||
Samples: s.Cnt,
|
||||
P90Exec: s.P90Exec,
|
||||
CapMaxConcurrency: capMax,
|
||||
OldMaxConcurrency: oldMax,
|
||||
NewMaxConcurrency: oldMax,
|
||||
CapQueueLimit: capQueue,
|
||||
OldQueueLimit: oldQueue,
|
||||
NewQueueLimit: oldQueue,
|
||||
ExpectedSeconds: m.ExpectedSeconds,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// arrival_rate ≈ 完成数/3600
|
||||
arrivalRate := float64(s.Cnt) / 3600.0
|
||||
|
||||
// desiredMax = ceil(arrivalRate * p90 / utilization)
|
||||
desiredMax := int(math.Ceil(arrivalRate * s.P90Exec / utilization))
|
||||
if desiredMax < 1 {
|
||||
desiredMax = 1
|
||||
}
|
||||
// 单次变化幅度限制
|
||||
minMax := int(math.Floor(float64(oldMax) * (1 - maxChangeRatio)))
|
||||
maxMax := int(math.Ceil(float64(oldMax) * (1 + maxChangeRatio)))
|
||||
if minMax < 1 {
|
||||
minMax = 1
|
||||
}
|
||||
newMax := clampInt(desiredMax, minMax, maxMax)
|
||||
if capMax > 0 {
|
||||
newMax = clampInt(newMax, 1, capMax)
|
||||
}
|
||||
setRuntimeInt(ctx, runtimeMaxConcurrencyKey(modelName), newMax)
|
||||
|
||||
// queue_limit:W_target = expected_seconds * queueFactor
|
||||
exp := m.ExpectedSeconds
|
||||
if exp <= 0 {
|
||||
exp = 60
|
||||
}
|
||||
wTarget := float64(exp) * queueFactor
|
||||
desiredQueue := int(math.Ceil(arrivalRate*wTarget)) + newMax
|
||||
if desiredQueue < newMax {
|
||||
desiredQueue = newMax
|
||||
}
|
||||
|
||||
newQueue := oldQueue
|
||||
if capQueue > 0 {
|
||||
minQ := int(math.Floor(float64(oldQueue) * (1 - maxChangeRatio)))
|
||||
maxQ := int(math.Ceil(float64(oldQueue) * (1 + maxChangeRatio)))
|
||||
if minQ < newMax {
|
||||
minQ = newMax
|
||||
}
|
||||
if maxQ < minQ {
|
||||
maxQ = minQ
|
||||
}
|
||||
newQueue = clampInt(desiredQueue, minQ, maxQ)
|
||||
newQueue = clampInt(newQueue, newMax, capQueue)
|
||||
setRuntimeInt(ctx, runtimeQueueLimitKey(modelName), newQueue)
|
||||
}
|
||||
|
||||
out = append(out, AutoTuneResult{
|
||||
ModelName: modelName,
|
||||
Samples: s.Cnt,
|
||||
P90Exec: s.P90Exec,
|
||||
CapMaxConcurrency: capMax,
|
||||
OldMaxConcurrency: oldMax,
|
||||
NewMaxConcurrency: newMax,
|
||||
CapQueueLimit: capQueue,
|
||||
OldQueueLimit: oldQueue,
|
||||
NewQueueLimit: newQueue,
|
||||
ExpectedSeconds: m.ExpectedSeconds,
|
||||
})
|
||||
}
|
||||
|
||||
g.Log().Infof(ctx, "[auto_tune] done models=%d windowSeconds=%d", len(out), windowSeconds)
|
||||
return out, nil
|
||||
}
|
||||
88
service/callback.go
Normal file
88
service/callback.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"model-asynch/model/entity"
|
||||
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
// triggerSuccessCallback 任务成功后的回调钩子:
|
||||
// - 使用 GET 请求
|
||||
// - 回调地址为 callbackUrl + "/" + bizName
|
||||
// - query 参数:task_id/state/oss_file/file_type
|
||||
// 注意:回调失败不影响任务主流程,只记录日志。
|
||||
func triggerSuccessCallback(ctx context.Context, t *entity.AsynchTask) {
|
||||
if t == nil {
|
||||
return
|
||||
}
|
||||
callbackURL := strings.TrimSpace(t.CallbackURL)
|
||||
bizName := strings.TrimSpace(t.BizName)
|
||||
if callbackURL == "" || bizName == "" {
|
||||
return
|
||||
}
|
||||
|
||||
u, err := url.Parse(callbackURL)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[callback] invalid callbackUrl=%s err=%v", callbackURL, err)
|
||||
return
|
||||
}
|
||||
// 必须是可发起 HTTP 请求的绝对地址
|
||||
if u.Scheme == "" || u.Host == "" {
|
||||
g.Log().Warningf(ctx, "[callback] callbackUrl must be absolute http(s) url, got=%s", callbackURL)
|
||||
return
|
||||
}
|
||||
|
||||
// path 末尾拼接 bizName
|
||||
bizSeg := url.PathEscape(bizName)
|
||||
if strings.HasSuffix(u.Path, "/") || u.Path == "" {
|
||||
u.Path = strings.TrimRight(u.Path, "/") + "/" + bizSeg
|
||||
} else {
|
||||
u.Path = u.Path + "/" + bizSeg
|
||||
}
|
||||
|
||||
q := u.Query()
|
||||
q.Set("task_id", t.TaskID)
|
||||
q.Set("state", fmt.Sprintf("%d", t.State))
|
||||
q.Set("oss_file", t.OssFile)
|
||||
q.Set("file_type", t.FileType)
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[callback] build request failed url=%s err=%v", u.String(), err)
|
||||
return
|
||||
}
|
||||
// 透传必要头部(如 Authorization / X-User-Info)
|
||||
for k, v := range forwardHeaders(ctx) {
|
||||
if v != "" {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[callback] request failed url=%s err=%v", u.String(), err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
msg := string(b)
|
||||
if len(msg) > 2000 {
|
||||
msg = msg[:2000]
|
||||
}
|
||||
g.Log().Warningf(ctx, "[callback] non-2xx url=%s code=%d body=%s", u.String(), resp.StatusCode, msg)
|
||||
return
|
||||
}
|
||||
g.Log().Infof(ctx, "[callback] success url=%s code=%d", u.String(), resp.StatusCode)
|
||||
}
|
||||
|
||||
92
service/cleaner.go
Normal file
92
service/cleaner.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"model-asynch/dao"
|
||||
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
var Cleaner = &cleaner{}
|
||||
|
||||
type cleaner struct{}
|
||||
|
||||
// RunOnce 由上层定时任务触发:执行一次清理/重试
|
||||
func (c *cleaner) RunOnce(ctx context.Context) {
|
||||
// 1) 清理已下载(state=4)且过期的任务(硬删除 + OSS)
|
||||
expired, err := dao.Task.ListExpiredDownloadedGlobal(ctx, 200)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[cleaner] list expired(downloaded) error: %v", err)
|
||||
} else {
|
||||
for _, t := range expired {
|
||||
deleteTmpResult(t.TmpFile)
|
||||
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
|
||||
}
|
||||
g.Log().Infof(ctx, "[cleaner] expired(downloaded) cleaned, count=%d", len(expired))
|
||||
}
|
||||
|
||||
// 2) 超时任务标失败
|
||||
list, err := dao.Task.ListTimeoutTasksGlobal(ctx, 200)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[cleaner] list timeout error: %v", err)
|
||||
} else {
|
||||
for _, t := range list {
|
||||
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, "任务超时自动失败")
|
||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
}
|
||||
g.Log().Infof(ctx, "[cleaner] timeout cleaned, count=%d", len(list))
|
||||
}
|
||||
|
||||
// 3) 失败(state=3)的任务按模型配置 retry_times 重新入队(放到队尾)
|
||||
retryable, err := dao.Task.ListFailedRetryableGlobal(ctx, 200)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[cleaner] list failed retryable error: %v", err)
|
||||
} else {
|
||||
for _, t := range retryable {
|
||||
// 失败任务重新入队(state=3 -> 0)前,先严格占用 queue_limit slot;占用失败则留在失败态,下一轮再尝试
|
||||
// 获取模型配置以得到 queue_limit / expected_seconds
|
||||
m, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
|
||||
if err != nil || m == nil {
|
||||
continue
|
||||
}
|
||||
limit := GetRuntimeQueueLimit(ctx, t.ModelName, m.QueueLimit)
|
||||
if limit > 0 {
|
||||
ok, _ := AcquireQueueSlot(ctx, t.ModelName, t.TaskID, limit, m.ExpectedSeconds)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
// retry_queue_max_seconds 控制失败重试的排队策略:
|
||||
// - =0:失败重试插队到队首
|
||||
// - >0:当任务从创建到现在的排队时长 >= maxSeconds,则插队到队首;否则仍放到队尾
|
||||
now := time.Now()
|
||||
enqueueAt := now
|
||||
maxSeconds := t.RetryQueueMaxSeconds
|
||||
if maxSeconds == 0 {
|
||||
enqueueAt = now.Add(-100 * 365 * 24 * time.Hour)
|
||||
} else if maxSeconds > 0 && t.CreatedAt != nil {
|
||||
if now.Sub(t.CreatedAt.Time) >= time.Duration(maxSeconds)*time.Second {
|
||||
enqueueAt = now.Add(-100 * 365 * 24 * time.Hour)
|
||||
}
|
||||
}
|
||||
_ = dao.Task.RequeueForRetryGlobal(ctx, t.Id, enqueueAt)
|
||||
}
|
||||
g.Log().Infof(ctx, "[cleaner] failed retryable cleaned, count=%d", len(retryable))
|
||||
}
|
||||
|
||||
// 4) 超过重试次数仍失败(state=3)的任务:硬删除
|
||||
exhausted, err := dao.Task.ListFailedExhaustedGlobal(ctx, 200)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[cleaner] list failed exhausted error: %v", err)
|
||||
} else {
|
||||
for _, t := range exhausted {
|
||||
deleteTmpResult(t.TmpFile)
|
||||
// 重试耗尽硬删除:释放闸门占位(兜底,若此前已释放则幂等)
|
||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
|
||||
}
|
||||
g.Log().Infof(ctx, "[cleaner] failed exhausted cleaned, count=%d", len(exhausted))
|
||||
}
|
||||
}
|
||||
35
service/file_detect.go
Normal file
35
service/file_detect.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// DetectFileType 根据返回的二进制内容推断 contentType + 扩展名(尽量稳定)
|
||||
func DetectFileType(data []byte) (contentType string, ext string) {
|
||||
if len(data) == 0 {
|
||||
return "application/octet-stream", ""
|
||||
}
|
||||
ct := http.DetectContentType(data)
|
||||
switch ct {
|
||||
case "audio/mpeg":
|
||||
return ct, ".mp3"
|
||||
case "audio/wave", "audio/wav", "audio/x-wav":
|
||||
return ct, ".wav"
|
||||
case "video/mp4":
|
||||
return ct, ".mp4"
|
||||
case "image/png":
|
||||
return ct, ".png"
|
||||
case "image/jpeg":
|
||||
return ct, ".jpg"
|
||||
case "application/pdf":
|
||||
return ct, ".pdf"
|
||||
default:
|
||||
// 兜底:尝试从 ct 截取 subtype 作为后缀(例如 application/json)
|
||||
if parts := strings.Split(ct, "/"); len(parts) == 2 {
|
||||
return ct, "." + parts[1]
|
||||
}
|
||||
return ct, ""
|
||||
}
|
||||
}
|
||||
|
||||
54
service/headers.go
Normal file
54
service/headers.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
// asyncCtx 固化异步执行所需的 token/user,避免请求结束后丢失(仅在“同请求内起 goroutine”有用)。
|
||||
// 本项目当前是“落库 + 后台 worker”模式,因此还会把必要信息持久化到任务表的 request_payload 中。
|
||||
func asyncCtx(ctx context.Context) context.Context {
|
||||
asyncCtx := context.WithoutCancel(ctx)
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
if token := r.Header.Get("Authorization"); token != "" {
|
||||
asyncCtx = context.WithValue(asyncCtx, "token", token)
|
||||
}
|
||||
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
|
||||
asyncCtx = context.WithValue(asyncCtx, "xUserInfo", userInfo)
|
||||
}
|
||||
}
|
||||
if user, err := utils.GetUserInfo(ctx); err == nil && user != nil {
|
||||
asyncCtx = context.WithValue(asyncCtx, "user", user)
|
||||
}
|
||||
return asyncCtx
|
||||
}
|
||||
|
||||
// forwardHeaders 透传调用链路中必须的头信息(优先使用 ctx 里固化的 token / xUserInfo)。
|
||||
func forwardHeaders(ctx context.Context) map[string]string {
|
||||
headers := make(map[string]string)
|
||||
|
||||
if token, ok := ctx.Value("token").(string); ok && token != "" {
|
||||
headers["Authorization"] = token
|
||||
}
|
||||
if x, ok := ctx.Value("xUserInfo").(string); ok && x != "" {
|
||||
headers["X-User-Info"] = x
|
||||
}
|
||||
|
||||
// 兜底:从请求头拿
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
if headers["Authorization"] == "" {
|
||||
if token := r.Header.Get("Authorization"); token != "" {
|
||||
headers["Authorization"] = token
|
||||
}
|
||||
}
|
||||
if headers["X-User-Info"] == "" {
|
||||
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
|
||||
headers["X-User-Info"] = userInfo
|
||||
}
|
||||
}
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
185
service/model_invoker.go
Normal file
185
service/model_invoker.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"model-asynch/model/entity"
|
||||
)
|
||||
|
||||
// parseHeadMsgHeaders 支持多个 header 绑定,逗号分隔:
|
||||
// 示例:
|
||||
// - X-API-Key:qwen3-tts-key,operation:true,count:123
|
||||
// - X-API-Key:"qwen3-tts-key",operation:"true"
|
||||
//
|
||||
// 说明:
|
||||
// - HTTP Header 最终都是字符串,这里做的是“值的字符串化表达”。
|
||||
// - 若 value 用双引号包裹,会去掉外层引号再注入,便于在配置中区分字符串/布尔/数字等表达(以及避免值中包含特殊字符时歧义)。
|
||||
func parseHeadMsgHeaders(headMsg string) map[string]string {
|
||||
headMsg = strings.TrimSpace(headMsg)
|
||||
if headMsg == "" {
|
||||
return nil
|
||||
}
|
||||
out := map[string]string{}
|
||||
parts := strings.Split(headMsg, ",")
|
||||
for _, p := range parts {
|
||||
p = strings.TrimSpace(p)
|
||||
if p == "" {
|
||||
continue
|
||||
}
|
||||
// HeaderName:HeaderValue(推荐) / HeaderName=HeaderValue(兼容)
|
||||
if strings.Contains(p, ":") {
|
||||
kv := strings.SplitN(p, ":", 2)
|
||||
k := strings.TrimSpace(kv[0])
|
||||
v := strings.TrimSpace(kv[1])
|
||||
v = strings.Trim(v, "\"")
|
||||
if k != "" && v != "" {
|
||||
out[k] = v
|
||||
}
|
||||
continue
|
||||
}
|
||||
if strings.Contains(p, "=") {
|
||||
kv := strings.SplitN(p, "=", 2)
|
||||
k := strings.TrimSpace(kv[0])
|
||||
v := strings.TrimSpace(kv[1])
|
||||
v = strings.Trim(v, "\"")
|
||||
if k != "" && v != "" {
|
||||
out[k] = v
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func payloadToQuery(payload any) (url.Values, error) {
|
||||
if payload == nil {
|
||||
return url.Values{}, nil
|
||||
}
|
||||
// 统一转成 map[string]any
|
||||
b, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m := map[string]any{}
|
||||
if err := json.Unmarshal(b, &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
q := url.Values{}
|
||||
for k, v := range m {
|
||||
if v == nil {
|
||||
continue
|
||||
}
|
||||
// 复杂类型直接 json 字符串化
|
||||
switch vv := v.(type) {
|
||||
case string:
|
||||
q.Set(k, vv)
|
||||
case float64, bool, int, int64, uint64:
|
||||
q.Set(k, fmt.Sprintf("%v", vv))
|
||||
default:
|
||||
bs, _ := json.Marshal(v)
|
||||
q.Set(k, string(bs))
|
||||
}
|
||||
}
|
||||
return q, nil
|
||||
}
|
||||
|
||||
// InvokeModel 调用模型服务,返回二进制结果
|
||||
// modelKey 用于覆盖/补充模型配置 head_msg(例如每次请求携带不同的 X-API-Key)。
|
||||
func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelKey string) ([]byte, error) {
|
||||
if m == nil || m.BaseURL == "" {
|
||||
return nil, fmt.Errorf("模型配置不完整")
|
||||
}
|
||||
url := strings.TrimRight(m.BaseURL, "/") + "/" + strings.TrimLeft(m.Route, "/")
|
||||
if strings.TrimSpace(m.Route) == "" {
|
||||
url = strings.TrimRight(m.BaseURL, "/")
|
||||
}
|
||||
|
||||
timeout := time.Duration(m.TimeoutSeconds) * time.Second
|
||||
if timeout <= 0 {
|
||||
timeout = 60 * time.Second
|
||||
}
|
||||
client := &http.Client{Timeout: timeout}
|
||||
|
||||
method := strings.ToUpper(strings.TrimSpace(m.HttpMethod))
|
||||
if method == "" {
|
||||
method = http.MethodPost
|
||||
}
|
||||
|
||||
var (
|
||||
req *http.Request
|
||||
err error
|
||||
)
|
||||
switch method {
|
||||
case http.MethodGet:
|
||||
q, err := payloadToQuery(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(q) > 0 {
|
||||
if strings.Contains(url, "?") {
|
||||
url = url + "&" + q.Encode()
|
||||
} else {
|
||||
url = url + "?" + q.Encode()
|
||||
}
|
||||
}
|
||||
req, err = http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
default:
|
||||
bodyBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err = http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyBytes))
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 先注入模型配置 head_msg(静态头部)
|
||||
for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) {
|
||||
req.Header.Set(hk, hv)
|
||||
}
|
||||
// 透传必要头部(如 Authorization / X-User-Info)
|
||||
for k, v := range forwardHeaders(ctx) {
|
||||
if v != "" {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
}
|
||||
// 最后注入动态 modelKey(覆盖/补充静态 head_msg)
|
||||
for hk, hv := range parseHeadMsgHeaders(modelKey) {
|
||||
req.Header.Set(hk, hv)
|
||||
}
|
||||
if method != http.MethodGet {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
// 尽量把错误体带回去,方便排查
|
||||
msg := string(b)
|
||||
if len(msg) > 2000 {
|
||||
msg = msg[:2000]
|
||||
}
|
||||
return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
125
service/model_service.go
Normal file
125
service/model_service.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"model-asynch/dao"
|
||||
"model-asynch/model/dto"
|
||||
"model-asynch/model/entity"
|
||||
)
|
||||
|
||||
var Model = &modelService{}
|
||||
|
||||
type modelService struct{}
|
||||
|
||||
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
|
||||
m := &entity.AsynchModel{
|
||||
ModelName: req.ModelName,
|
||||
ModelsType: normalizeModelsType(req.ModelsType),
|
||||
BaseURL: req.BaseURL,
|
||||
Route: req.Route,
|
||||
HttpMethod: req.HttpMethod,
|
||||
HeadMsg: req.HeadMsg,
|
||||
Form: req.Form,
|
||||
Enabled: req.Enabled,
|
||||
MaxConcurrency: req.MaxConcurrency,
|
||||
QueueLimit: req.QueueLimit,
|
||||
TimeoutSeconds: req.TimeoutSeconds,
|
||||
ExpectedSeconds: req.ExpectedSeconds,
|
||||
RetryTimes: req.RetryTimes,
|
||||
RetryQueueMaxSecs: req.RetryQueueMaxSeconds,
|
||||
AutoCleanSeconds: req.AutoCleanSeconds,
|
||||
Remark: req.Remark,
|
||||
}
|
||||
if m.HttpMethod == "" {
|
||||
m.HttpMethod = "POST"
|
||||
}
|
||||
if m.Enabled == 0 {
|
||||
m.Enabled = 1
|
||||
}
|
||||
if m.MaxConcurrency <= 0 {
|
||||
m.MaxConcurrency = 10
|
||||
}
|
||||
if m.QueueLimit <= 0 {
|
||||
m.QueueLimit = 1000
|
||||
}
|
||||
if m.TimeoutSeconds <= 0 {
|
||||
m.TimeoutSeconds = 60
|
||||
}
|
||||
if m.AutoCleanSeconds <= 0 {
|
||||
m.AutoCleanSeconds = 86400
|
||||
}
|
||||
id, err := dao.Model.Insert(ctx, m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.CreateModelRes{ID: id}, nil
|
||||
}
|
||||
|
||||
func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) error {
|
||||
data := map[string]any{}
|
||||
if req.BaseURL != "" {
|
||||
data[entity.AsynchModelCol.BaseURL] = req.BaseURL
|
||||
}
|
||||
if req.Route != "" {
|
||||
data[entity.AsynchModelCol.Route] = req.Route
|
||||
}
|
||||
if req.HttpMethod != nil && *req.HttpMethod != "" {
|
||||
data[entity.AsynchModelCol.HttpMethod] = *req.HttpMethod
|
||||
}
|
||||
if req.HeadMsg != nil {
|
||||
data[entity.AsynchModelCol.HeadMsg] = *req.HeadMsg
|
||||
}
|
||||
if req.Form != nil {
|
||||
data[entity.AsynchModelCol.FormJSON] = req.Form
|
||||
}
|
||||
if req.ModelsType != nil {
|
||||
data[entity.AsynchModelCol.ModelsType] = normalizeModelsType(*req.ModelsType)
|
||||
}
|
||||
if req.Enabled != nil {
|
||||
data[entity.AsynchModelCol.Enabled] = *req.Enabled
|
||||
}
|
||||
if req.MaxConcurrency != nil {
|
||||
data[entity.AsynchModelCol.MaxConcurrency] = *req.MaxConcurrency
|
||||
}
|
||||
if req.QueueLimit != nil {
|
||||
data[entity.AsynchModelCol.QueueLimit] = *req.QueueLimit
|
||||
}
|
||||
if req.TimeoutSeconds != nil {
|
||||
data[entity.AsynchModelCol.TimeoutSeconds] = *req.TimeoutSeconds
|
||||
}
|
||||
if req.ExpectedSeconds != nil {
|
||||
data[entity.AsynchModelCol.ExpectedSeconds] = *req.ExpectedSeconds
|
||||
}
|
||||
if req.RetryTimes != nil {
|
||||
data[entity.AsynchModelCol.RetryTimes] = *req.RetryTimes
|
||||
}
|
||||
if req.RetryQueueMaxSeconds != nil {
|
||||
data[entity.AsynchModelCol.RetryQueueMaxSecs] = *req.RetryQueueMaxSeconds
|
||||
}
|
||||
if req.AutoCleanSeconds != nil {
|
||||
data[entity.AsynchModelCol.AutoCleanSeconds] = *req.AutoCleanSeconds
|
||||
}
|
||||
if req.Remark != nil {
|
||||
data[entity.AsynchModelCol.Remark] = *req.Remark
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return errors.New("无可更新字段")
|
||||
}
|
||||
_, err := dao.Model.UpdateByID(ctx, req.ID, data)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *modelService) Delete(ctx context.Context, id int64) error {
|
||||
_, err := dao.Model.DeleteByID(ctx, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *modelService) Get(ctx context.Context, id int64) (*entity.AsynchModel, error) {
|
||||
return dao.Model.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *modelService) List(ctx context.Context, pageNum, pageSize int, modelNameLike string) (list []*entity.AsynchModel, total int64, err error) {
|
||||
return dao.Model.List(ctx, pageNum, pageSize, modelNameLike)
|
||||
}
|
||||
217
service/model_type_service.go
Normal file
217
service/model_type_service.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"model-asynch/dao"
|
||||
"model-asynch/model/dto"
|
||||
"model-asynch/model/entity"
|
||||
|
||||
"github.com/gogf/gf/v2/container/gvar"
|
||||
)
|
||||
|
||||
type modelTypeService struct{}
|
||||
|
||||
var ModelType = &modelTypeService{}
|
||||
|
||||
func normalizeFormValue(v any) any {
|
||||
// 目标:对外永远返回 JSON 数组/对象,而不是字符串。
|
||||
if v == nil {
|
||||
return []any{}
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case string:
|
||||
s := strings.TrimSpace(t)
|
||||
if s == "" {
|
||||
return []any{}
|
||||
}
|
||||
return normalizeFormValueFromJSONString(s)
|
||||
case []byte:
|
||||
if len(t) == 0 {
|
||||
return []any{}
|
||||
}
|
||||
return normalizeFormValueFromJSONBytes(t)
|
||||
case *gvar.Var:
|
||||
// goframe 常见的 DB 返回类型
|
||||
if t == nil {
|
||||
return []any{}
|
||||
}
|
||||
b := t.Bytes()
|
||||
if len(b) > 0 {
|
||||
return normalizeFormValueFromJSONBytes(b)
|
||||
}
|
||||
s := strings.TrimSpace(t.String())
|
||||
if s == "" {
|
||||
return []any{}
|
||||
}
|
||||
return normalizeFormValueFromJSONString(s)
|
||||
default:
|
||||
// 尝试兼容其他“像 JSON 的值类型”(例如实现了 Bytes/String 的包装类型)
|
||||
if vb, ok := v.(interface{ Bytes() []byte }); ok {
|
||||
if b := vb.Bytes(); len(b) > 0 {
|
||||
return normalizeFormValueFromJSONBytes(b)
|
||||
}
|
||||
}
|
||||
if vs, ok := v.(interface{ String() string }); ok {
|
||||
if s := strings.TrimSpace(vs.String()); s != "" {
|
||||
return normalizeFormValueFromJSONString(s)
|
||||
}
|
||||
}
|
||||
// 已经是 []any / map[string]any 等结构
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
// 兼容“JSONB 里存了 JSON 字符串”的历史数据:
|
||||
// 例如 form_json = '"[]"' 或 '"[{...}]"'(外层是字符串,内层才是数组/对象)
|
||||
func normalizeFormValueFromJSONString(s string) any {
|
||||
var out any
|
||||
if err := json.Unmarshal([]byte(s), &out); err != nil || out == nil {
|
||||
return []any{}
|
||||
}
|
||||
// 如果解出来还是 string,且看起来是 JSON,再解一层
|
||||
if inner, ok := out.(string); ok {
|
||||
inner = strings.TrimSpace(inner)
|
||||
if inner == "" {
|
||||
return []any{}
|
||||
}
|
||||
if strings.HasPrefix(inner, "[") || strings.HasPrefix(inner, "{") {
|
||||
var out2 any
|
||||
if err := json.Unmarshal([]byte(inner), &out2); err == nil && out2 != nil {
|
||||
return out2
|
||||
}
|
||||
}
|
||||
return []any{}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeFormValueFromJSONBytes(b []byte) any {
|
||||
var out any
|
||||
if err := json.Unmarshal(b, &out); err != nil || out == nil {
|
||||
return []any{}
|
||||
}
|
||||
// bytes 解出来也可能是 string(同上)
|
||||
if inner, ok := out.(string); ok {
|
||||
return normalizeFormValueFromJSONString(inner)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *modelTypeService) Create(ctx context.Context, req *dto.CreateModelTypeReq) (res *dto.CreateModelTypeRes, err error) {
|
||||
t := &entity.AsynchModelType{
|
||||
TypeID: req.TypeID,
|
||||
TypeName: req.TypeName,
|
||||
Remark: req.Remark,
|
||||
}
|
||||
id, err := dao.ModelType.Insert(ctx, t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.CreateModelTypeRes{ID: id}, nil
|
||||
}
|
||||
|
||||
func (s *modelTypeService) Update(ctx context.Context, req *dto.UpdateModelTypeReq) error {
|
||||
data := map[string]any{}
|
||||
if req.TypeID != nil {
|
||||
data[entity.AsynchModelTypeCol.TypeID] = *req.TypeID
|
||||
}
|
||||
if req.TypeName != nil {
|
||||
data[entity.AsynchModelTypeCol.TypeName] = *req.TypeName
|
||||
}
|
||||
if req.Remark != nil {
|
||||
data[entity.AsynchModelTypeCol.Remark] = *req.Remark
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return errors.New("无可更新字段")
|
||||
}
|
||||
_, err := dao.ModelType.UpdateByID(ctx, req.ID, data)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *modelTypeService) Delete(ctx context.Context, id int64) error {
|
||||
_, err := dao.ModelType.DeleteByID(ctx, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *modelTypeService) Get(ctx context.Context, id int64) (*entity.AsynchModelType, error) {
|
||||
return dao.ModelType.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *modelTypeService) List(ctx context.Context, pageNum, pageSize int, typeNameLike string) (list []*entity.AsynchModelType, total int64, err error) {
|
||||
return dao.ModelType.List(ctx, pageNum, pageSize, typeNameLike)
|
||||
}
|
||||
|
||||
// ListWithModels 按类型分组返回模型(返回数组,便于前端直接渲染)
|
||||
func (s *modelTypeService) ListWithModels(ctx context.Context, req *dto.ListModelTypeWithModelsReq) (res []dto.ModelTypeWithModelsItem, err error) {
|
||||
types, _, err := dao.ModelType.List(ctx, 1, 1000, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 过滤类型(按 typeId / typeName 模糊)
|
||||
filterTypeID := 0
|
||||
filterTypeName := ""
|
||||
if req != nil {
|
||||
filterTypeID = req.TypeID
|
||||
filterTypeName = strings.TrimSpace(req.Type)
|
||||
}
|
||||
typeIDs := make([]int, 0, len(types))
|
||||
typeNameMap := make(map[int]string, len(types))
|
||||
for _, t := range types {
|
||||
if t == nil {
|
||||
continue
|
||||
}
|
||||
if filterTypeID > 0 && t.TypeID != filterTypeID {
|
||||
continue
|
||||
}
|
||||
if filterTypeName != "" && !strings.Contains(t.TypeName, filterTypeName) {
|
||||
continue
|
||||
}
|
||||
typeIDs = append(typeIDs, t.TypeID)
|
||||
typeNameMap[t.TypeID] = t.TypeName
|
||||
}
|
||||
models, err := dao.Model.ListAll(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
itemsMap := map[int][]dto.ModelTypeModelItem{}
|
||||
for _, m := range models {
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
form := normalizeFormValue(m.Form)
|
||||
// 一个模型可能支持多个类型:models_type="1,2,3"
|
||||
for _, tid := range parseModelsTypeIDs(m.ModelsType) {
|
||||
// 若请求过滤了类型,则只输出该类型
|
||||
if filterTypeID > 0 && tid != filterTypeID {
|
||||
continue
|
||||
}
|
||||
if filterTypeName != "" {
|
||||
if _, ok := typeNameMap[tid]; !ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
itemsMap[tid] = append(itemsMap[tid], dto.ModelTypeModelItem{
|
||||
ID: m.Id,
|
||||
Name: m.ModelName,
|
||||
Form: form,
|
||||
})
|
||||
}
|
||||
}
|
||||
out := make([]dto.ModelTypeWithModelsItem, 0, len(typeIDs))
|
||||
for _, tid := range typeIDs {
|
||||
items := itemsMap[tid]
|
||||
if items == nil {
|
||||
items = make([]dto.ModelTypeModelItem, 0)
|
||||
}
|
||||
out = append(out, dto.ModelTypeWithModelsItem{
|
||||
TypeID: tid,
|
||||
Type: typeNameMap[tid],
|
||||
Items: items,
|
||||
})
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
52
service/model_types_util.go
Normal file
52
service/model_types_util.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// normalizeModelsType 将 "1, 2,2,3" 归一化为 "1,2,3"
|
||||
// - 去空格
|
||||
// - 去重
|
||||
// - 升序排序
|
||||
func normalizeModelsType(v string) string {
|
||||
ids := parseModelsTypeIDs(v)
|
||||
if len(ids) == 0 {
|
||||
return ""
|
||||
}
|
||||
parts := make([]string, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
parts = append(parts, strconv.Itoa(id))
|
||||
}
|
||||
return strings.Join(parts, ",")
|
||||
}
|
||||
|
||||
// parseModelsTypeIDs 解析 models_type 字段(支持 "1,2,3"),返回去重后的 int 列表(升序)。
|
||||
func parseModelsTypeIDs(v string) []int {
|
||||
v = strings.TrimSpace(v)
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
raw := strings.Split(v, ",")
|
||||
seen := map[int]struct{}{}
|
||||
out := make([]int, 0, len(raw))
|
||||
for _, s := range raw {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" || s == "0" {
|
||||
continue
|
||||
}
|
||||
id, err := strconv.Atoi(s)
|
||||
if err != nil || id <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
out = append(out, id)
|
||||
}
|
||||
sort.Ints(out)
|
||||
return out
|
||||
}
|
||||
|
||||
25
service/payload.go
Normal file
25
service/payload.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package service
|
||||
|
||||
import "github.com/gogf/gf/v2/util/gconv"
|
||||
|
||||
// parseStoredPayload 解析入库的 request_payload,拆出模型调用 payload 与透传 headers
|
||||
// 入库格式:{"payload": <any>, "headers": {"Authorization": "...", "X-User-Info":"..."}}
|
||||
func parseStoredPayload(v any) (payload any, headers map[string]string) {
|
||||
if v == nil {
|
||||
return nil, nil
|
||||
}
|
||||
m := gconv.Map(v)
|
||||
if len(m) == 0 {
|
||||
return v, nil
|
||||
}
|
||||
if h, ok := m["headers"]; ok {
|
||||
headers = gconv.MapStrStr(h)
|
||||
}
|
||||
if p, ok := m["payload"]; ok {
|
||||
payload = p
|
||||
} else {
|
||||
payload = v
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
107
service/queue_gate.go
Normal file
107
service/queue_gate.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// ===== 严格 queue_limit:Redis 原子闸门 =====
|
||||
//
|
||||
// 背景:原来的 queue_limit 通过“Count + Insert”做近似控制,分布式并发创建时会短暂超限。
|
||||
// 目标:以 Redis Lua 脚本实现原子校验 + 入队占位,做到严格不超限。
|
||||
//
|
||||
// 计数口径与原逻辑保持一致:只统计 state=0/1(排队中/执行中)。
|
||||
// - CreateTask 成功入库后占用 1 个 slot
|
||||
// - 任务成功/失败(state->2/3)释放 slot
|
||||
// - 失败任务重试(state 3->0)需要再次占用 slot,若占位失败则暂不重试(留在 state=3,下次 cleaner 再尝试)
|
||||
//
|
||||
// 说明:为避免极端情况下“占位泄漏”导致永久占满,采用 ZSET + 过期时间的方式自动回收。
|
||||
// 只要任务实际生命周期远小于 gateTTLSeconds,就可保持严格。
|
||||
|
||||
const (
|
||||
queueGateKeyPrefix = "asynch:qgate:" // asynch:qgate:{modelName}
|
||||
)
|
||||
|
||||
// Lua:清理过期 slot,然后按 limit 做原子判定并占位
|
||||
var queueGateAcquireLua = `
|
||||
local key = KEYS[1]
|
||||
local now = tonumber(ARGV[1])
|
||||
local limit = tonumber(ARGV[2])
|
||||
local expireAt = tonumber(ARGV[3])
|
||||
local member = ARGV[4]
|
||||
local keyTTL = tonumber(ARGV[5])
|
||||
|
||||
-- 先清理过期的占位
|
||||
redis.call("ZREMRANGEBYSCORE", key, "-inf", now)
|
||||
|
||||
local current = tonumber(redis.call("ZCARD", key) or "0")
|
||||
if current >= limit then
|
||||
return 0
|
||||
end
|
||||
redis.call("ZADD", key, expireAt, member)
|
||||
redis.call("EXPIRE", key, keyTTL)
|
||||
return 1
|
||||
`
|
||||
|
||||
// Lua:释放 slot(幂等)
|
||||
var queueGateReleaseLua = `
|
||||
local key = KEYS[1]
|
||||
local member = ARGV[1]
|
||||
redis.call("ZREM", key, member)
|
||||
return 1
|
||||
`
|
||||
|
||||
func queueGateKey(modelName string) string {
|
||||
return fmt.Sprintf("%s%s", queueGateKeyPrefix, modelName)
|
||||
}
|
||||
|
||||
// calcGateTTLSeconds 计算闸门占位的“自动回收 TTL”
|
||||
// 取 expectedSeconds 的倍数并做上下限,避免任务异常导致永久占位。
|
||||
func calcGateTTLSeconds(expectedSeconds int) int {
|
||||
// 默认至少 1 小时;最多 24 小时
|
||||
minTTL := 3600
|
||||
maxTTL := 24 * 3600
|
||||
if expectedSeconds <= 0 {
|
||||
return minTTL
|
||||
}
|
||||
ttl := int(math.Ceil(float64(expectedSeconds) * 10)) // 预计耗时 * 10 做兜底
|
||||
if ttl < minTTL {
|
||||
ttl = minTTL
|
||||
}
|
||||
if ttl > maxTTL {
|
||||
ttl = maxTTL
|
||||
}
|
||||
return ttl
|
||||
}
|
||||
|
||||
// AcquireQueueSlot 严格入队:原子占位(成功返回 true)
|
||||
func AcquireQueueSlot(ctx context.Context, modelName, taskId string, limit int, expectedSeconds int) (bool, error) {
|
||||
if limit <= 0 {
|
||||
return true, nil
|
||||
}
|
||||
key := queueGateKey(modelName)
|
||||
now := time.Now().Unix()
|
||||
ttl := calcGateTTLSeconds(expectedSeconds)
|
||||
expireAt := now + int64(ttl)
|
||||
// keyTTL 要略大于 member TTL,避免 key 先过期导致计数丢失
|
||||
keyTTL := ttl + 60
|
||||
r, err := g.Redis().Do(ctx, "EVAL", queueGateAcquireLua, 1, key, now, limit, expireAt, taskId, keyTTL)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("queue gate acquire failed: %w", err)
|
||||
}
|
||||
return gconv.Int(r) == 1, nil
|
||||
}
|
||||
|
||||
// ReleaseQueueSlot 释放占位(幂等)
|
||||
func ReleaseQueueSlot(ctx context.Context, modelName, taskId string) {
|
||||
if taskId == "" || modelName == "" {
|
||||
return
|
||||
}
|
||||
key := queueGateKey(modelName)
|
||||
_, _ = g.Redis().Do(ctx, "EVAL", queueGateReleaseLua, 1, key, taskId)
|
||||
}
|
||||
83
service/runtime_tune.go
Normal file
83
service/runtime_tune.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// 运行时调参存储在 Redis,不修改 asynch_models 中的 cap(最大上限)。
|
||||
// 上层每小时调用 /model/autoTune 写入运行时值;Worker/CreateTask 读取运行时值生效。
|
||||
|
||||
const (
|
||||
runtimeMaxCKeyPrefix = "asynch:runtime:max_concurrency:" // + model_name
|
||||
runtimeQueueKeyPrefix = "asynch:runtime:queue_limit:" // + model_name
|
||||
runtimeTTLSeconds = 2 * 3600 // 2小时,避免一次调参失败导致立即回退
|
||||
)
|
||||
|
||||
func runtimeMaxConcurrencyKey(modelName string) string {
|
||||
return runtimeMaxCKeyPrefix + modelName
|
||||
}
|
||||
func runtimeQueueLimitKey(modelName string) string {
|
||||
return runtimeQueueKeyPrefix + modelName
|
||||
}
|
||||
|
||||
func getRuntimeInt(ctx context.Context, key string) (int, bool) {
|
||||
v, err := g.Redis().Do(ctx, "GET", key)
|
||||
if err != nil || v == nil {
|
||||
return 0, false
|
||||
}
|
||||
iv := gconv.Int(v)
|
||||
if iv <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
return iv, true
|
||||
}
|
||||
|
||||
func setRuntimeInt(ctx context.Context, key string, val int) {
|
||||
if val <= 0 {
|
||||
return
|
||||
}
|
||||
// SETEX key ttl val
|
||||
_, _ = g.Redis().Do(ctx, "SETEX", key, runtimeTTLSeconds, val)
|
||||
}
|
||||
|
||||
// GetRuntimeMaxConcurrency 返回运行时并发上限(<= cap)。若不存在运行时值,则返回 cap。
|
||||
func GetRuntimeMaxConcurrency(ctx context.Context, modelName string, cap int) int {
|
||||
if cap <= 0 {
|
||||
return cap
|
||||
}
|
||||
if v, ok := getRuntimeInt(ctx, runtimeMaxConcurrencyKey(modelName)); ok {
|
||||
if v > cap {
|
||||
return cap
|
||||
}
|
||||
return v
|
||||
}
|
||||
return cap
|
||||
}
|
||||
|
||||
// GetRuntimeQueueLimit 返回运行时队列上限(<= cap)。若不存在运行时值,则返回 cap。
|
||||
func GetRuntimeQueueLimit(ctx context.Context, modelName string, cap int) int {
|
||||
if cap <= 0 {
|
||||
return cap
|
||||
}
|
||||
if v, ok := getRuntimeInt(ctx, runtimeQueueLimitKey(modelName)); ok {
|
||||
if v > cap {
|
||||
return cap
|
||||
}
|
||||
return v
|
||||
}
|
||||
return cap
|
||||
}
|
||||
|
||||
func clampInt(v, minV, maxV int) int {
|
||||
if v < minV {
|
||||
return minV
|
||||
}
|
||||
if v > maxV {
|
||||
return maxV
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
56
service/semaphore.go
Normal file
56
service/semaphore.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
var acquireLua = `
|
||||
local current = tonumber(redis.call("GET", KEYS[1]) or "0")
|
||||
local max = tonumber(ARGV[1])
|
||||
local ttl = tonumber(ARGV[2])
|
||||
if current >= max then
|
||||
return 0
|
||||
end
|
||||
current = redis.call("INCR", KEYS[1])
|
||||
if current == 1 then
|
||||
redis.call("EXPIRE", KEYS[1], ttl)
|
||||
end
|
||||
if current > max then
|
||||
redis.call("DECR", KEYS[1])
|
||||
return 0
|
||||
end
|
||||
return 1
|
||||
`
|
||||
|
||||
var releaseLua = `
|
||||
local current = tonumber(redis.call("DECR", KEYS[1]) or "0")
|
||||
if current <= 0 then
|
||||
redis.call("DEL", KEYS[1])
|
||||
end
|
||||
return 1
|
||||
`
|
||||
|
||||
func acquireSemaphore(ctx context.Context, key string, max int, ttlSeconds int64) (bool, error) {
|
||||
if max <= 0 {
|
||||
// 不限制
|
||||
return true, nil
|
||||
}
|
||||
if ttlSeconds <= 0 {
|
||||
ttlSeconds = 3600
|
||||
}
|
||||
r, err := g.Redis().Do(ctx, "EVAL", acquireLua, 1, key, max, ttlSeconds)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("获取并发令牌失败: %w", err)
|
||||
}
|
||||
return gconv.Int(r) == 1, nil
|
||||
}
|
||||
|
||||
func releaseSemaphore(ctx context.Context, key string) error {
|
||||
_, err := g.Redis().Do(ctx, "EVAL", releaseLua, 1, key)
|
||||
return err
|
||||
}
|
||||
|
||||
40
service/stat_service.go
Normal file
40
service/stat_service.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"model-asynch/dao"
|
||||
"model-asynch/model/dto"
|
||||
)
|
||||
|
||||
type statService struct{}
|
||||
|
||||
var Stat = &statService{}
|
||||
|
||||
func (s *statService) List(ctx context.Context, req *dto.ListModelStatReq) (res *dto.ListModelStatRes, err error) {
|
||||
pageNum, pageSize := 1, 10
|
||||
if req != nil && req.Page != nil {
|
||||
if req.Page.PageNum > 0 {
|
||||
pageNum = int(req.Page.PageNum)
|
||||
}
|
||||
if req.Page.PageSize > 0 {
|
||||
pageSize = int(req.Page.PageSize)
|
||||
}
|
||||
}
|
||||
startDay, endDay := "", ""
|
||||
var tenantID *int64
|
||||
creator, modelName := "", ""
|
||||
if req != nil {
|
||||
startDay = req.StartDay
|
||||
endDay = req.EndDay
|
||||
tenantID = req.TenantID
|
||||
creator = req.Creator
|
||||
modelName = req.ModelName
|
||||
}
|
||||
list, total, err := dao.Stat.List(ctx, pageNum, pageSize, startDay, endDay, tenantID, creator, modelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.ListModelStatRes{List: list, Total: total}, nil
|
||||
}
|
||||
|
||||
18
service/storage.go
Normal file
18
service/storage.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"model-asynch/model/entity"
|
||||
)
|
||||
|
||||
// StorageService 结果存储(OSS/MinIO)抽象
|
||||
type StorageService interface {
|
||||
UploadByTask(ctx context.Context, t *entity.AsynchTask, data []byte, fileExt string, contentType string) (ossURL string, err error)
|
||||
}
|
||||
|
||||
// Storage 默认存储实现(优先对接你们的 oss 文件服务;必要时也可以切到 MinIO)
|
||||
var Storage StorageService = &ossStorage{}
|
||||
|
||||
var ErrStorageNotConfigured = errors.New("存储未配置")
|
||||
82
service/storage_oss.go
Normal file
82
service/storage_oss.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"mime/multipart"
|
||||
"time"
|
||||
|
||||
"model-asynch/model/entity"
|
||||
|
||||
commonHttp "gitea.com/red-future/common/http"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
"github.com/gogf/gf/v2/util/guid"
|
||||
)
|
||||
|
||||
// 对接你们的 oss 文件服务:POST oss/file/uploadFile (multipart/form-data)
|
||||
type ossStorage struct{}
|
||||
|
||||
type uploadFileResponse struct {
|
||||
FileURL string `json:"fileURL"` // 文件 URL
|
||||
FileSize int `json:"fileSize"` // 文件大小(字节)
|
||||
FileName string `json:"fileName"` // 文件名
|
||||
FileFormat string `json:"fileFormat"` // 文件格式
|
||||
FileAddressPrefix string `json:"fileAddressPrefix"` // 文件地址前缀
|
||||
}
|
||||
|
||||
func (s *ossStorage) UploadByTask(ctx context.Context, _ *entity.AsynchTask, data []byte, fileExt string, _ string) (ossURL string, err error) {
|
||||
// multipart
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
|
||||
ext := fileExt
|
||||
if ext == "" {
|
||||
ext = ".bin"
|
||||
}
|
||||
if ext[0] != '.' {
|
||||
ext = "." + ext
|
||||
}
|
||||
|
||||
filename := fmt.Sprintf("asynch_%d_%s%s", time.Now().Unix(), guid.S(), ext)
|
||||
part, err := writer.CreateFormFile("file", filename)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if _, err := part.Write(data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
contentType := writer.FormDataContentType()
|
||||
if err := writer.Close(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
headers := forwardHeaders(ctx)
|
||||
headers["Content-Type"] = contentType
|
||||
|
||||
fullURL := "oss/file/uploadFile"
|
||||
g.Log().Infof(ctx, "[OSS] upload start url=%s filename=%s size=%d", fullURL, filename, len(data))
|
||||
|
||||
var resp uploadFileResponse
|
||||
if err := commonHttp.Post(ctx, fullURL, headers, &resp, body.Bytes()); err != nil {
|
||||
return "", err
|
||||
}
|
||||
fmt.Println("打印结果 resp:", resp)
|
||||
g.Log().Infof(ctx, "[OSS] upload success url=%s size=%d format=%s", resp.FileURL, resp.FileSize, resp.FileFormat)
|
||||
return resp.FileURL, nil
|
||||
}
|
||||
|
||||
// setTaskHeadersToCtx 把任务入库时保存的 header 信息注入 ctx,给 worker 调 OSS 用
|
||||
func setTaskHeadersToCtx(ctx context.Context, headers map[string]string) context.Context {
|
||||
if headers == nil {
|
||||
return ctx
|
||||
}
|
||||
if v := gconv.String(headers["Authorization"]); v != "" {
|
||||
ctx = context.WithValue(ctx, "token", v)
|
||||
}
|
||||
if v := gconv.String(headers["X-User-Info"]); v != "" {
|
||||
ctx = context.WithValue(ctx, "xUserInfo", v)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
192
service/task_service.go
Normal file
192
service/task_service.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"model-asynch/dao"
|
||||
"model-asynch/model/dto"
|
||||
"model-asynch/model/entity"
|
||||
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/os/gtime"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var Task = &taskService{}
|
||||
|
||||
type taskService struct{}
|
||||
|
||||
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
|
||||
startAt := time.Now()
|
||||
// 固化 token/user 等信息
|
||||
ctx = asyncCtx(ctx)
|
||||
|
||||
// 1) 检查模型配置
|
||||
m, err := dao.Model.GetByModelName(ctx, req.ModelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if m == nil || m.Enabled != 1 {
|
||||
return nil, errors.New("模型不存在或未启用")
|
||||
}
|
||||
|
||||
taskID := uuid.NewString()
|
||||
// 2) 排队上限(严格控制:Redis 原子闸门)
|
||||
limit := GetRuntimeQueueLimit(ctx, req.ModelName, m.QueueLimit)
|
||||
if limit > 0 {
|
||||
ok, err := AcquireQueueSlot(ctx, req.ModelName, taskID, limit, m.ExpectedSeconds)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !ok {
|
||||
return nil, errors.New("任务排队已满,请稍后再试")
|
||||
}
|
||||
}
|
||||
|
||||
// 将调用模型的 payload 与透传头信息一起存入 request_payload,供后台 worker 使用
|
||||
storedPayload := map[string]any{
|
||||
"payload": req.RequestPayload,
|
||||
"headers": forwardHeaders(ctx),
|
||||
}
|
||||
|
||||
t := &entity.AsynchTask{
|
||||
ModelName: req.ModelName,
|
||||
TaskID: taskID,
|
||||
State: 0,
|
||||
BizName: req.BizName,
|
||||
CallbackURL: req.CallbackUrl,
|
||||
ModelKey: req.ModelKey,
|
||||
InputRef: req.InputRef,
|
||||
RequestPayload: storedPayload,
|
||||
}
|
||||
_, err = dao.Task.Insert(ctx, t)
|
||||
if err != nil {
|
||||
// 入库失败:回滚闸门占位
|
||||
ReleaseQueueSlot(ctx, req.ModelName, taskID)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 3) 写操作日志(尽量不影响主流程,失败忽略)
|
||||
ip := ""
|
||||
ua := ""
|
||||
apiPath := "/task/createTask"
|
||||
httpMethod := "POST"
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
ip = r.GetClientIp()
|
||||
ua = r.UserAgent()
|
||||
apiPath = r.URL.Path
|
||||
httpMethod = r.Method
|
||||
}
|
||||
_, _ = dao.OpLog.Insert(ctx, &entity.AsynchOpLog{
|
||||
IP: ip,
|
||||
UserAgent: ua,
|
||||
APIPath: apiPath,
|
||||
HttpMethod: httpMethod,
|
||||
BizName: req.BizName,
|
||||
ModelName: req.ModelName,
|
||||
TaskID: taskID,
|
||||
OpType: "createTask",
|
||||
Success: 1,
|
||||
ErrorMsg: "",
|
||||
CostMs: time.Since(startAt).Milliseconds(),
|
||||
RequestPayload: storedPayload,
|
||||
ResponsePayload: gdb.Map{
|
||||
"taskId": taskID,
|
||||
},
|
||||
})
|
||||
return &dto.CreateTaskRes{TaskID: taskID}, nil
|
||||
}
|
||||
|
||||
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
|
||||
t, err := dao.Task.GetByTaskID(ctx, taskID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if t == nil {
|
||||
return nil, errors.New("任务不存在")
|
||||
}
|
||||
return &dto.GetTaskResultRes{
|
||||
OssFile: t.OssFile,
|
||||
State: t.State,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetBatch 批量查询任务;将成功(state=2)的任务更新为已下载(state=4),并写入过期时间
|
||||
func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
|
||||
if req == nil || len(req.TaskIDs) == 0 {
|
||||
return &dto.GetTaskBatchRes{List: []dto.GetTaskBatchItem{}}, nil
|
||||
}
|
||||
// 1) 先查当前租户下的任务列表
|
||||
list, err := dao.Task.ListByTaskIDs(ctx, req.TaskIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2) 对成功(state=2)的任务:标记为已下载(state=4)并写入 expire_at
|
||||
now := time.Now()
|
||||
for _, t := range list {
|
||||
if t == nil {
|
||||
continue
|
||||
}
|
||||
if t.State != 2 {
|
||||
continue
|
||||
}
|
||||
// 按模型配置决定保留时间
|
||||
m, err := dao.Model.GetByModelName(ctx, t.ModelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
retainSeconds := 86400
|
||||
if m != nil && m.AutoCleanSeconds > 0 {
|
||||
retainSeconds = m.AutoCleanSeconds
|
||||
}
|
||||
expireAt := gtime.New(now.Add(time.Duration(retainSeconds) * time.Second))
|
||||
_ = dao.Task.MarkDownloadedByID(ctx, t.Id, expireAt)
|
||||
|
||||
// 为了本次返回一致性,内存里也更新
|
||||
t.State = 4
|
||||
t.ExpireAt = expireAt
|
||||
}
|
||||
|
||||
// 3) 组装返回
|
||||
items := make([]dto.GetTaskBatchItem, 0, len(list))
|
||||
for _, t := range list {
|
||||
if t == nil {
|
||||
continue
|
||||
}
|
||||
items = append(items, dto.GetTaskBatchItem{
|
||||
TaskID: t.TaskID,
|
||||
State: t.State,
|
||||
OssFile: t.OssFile,
|
||||
})
|
||||
}
|
||||
return &dto.GetTaskBatchRes{List: items}, nil
|
||||
}
|
||||
|
||||
func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
|
||||
pageNum, pageSize := 1, 10
|
||||
if req != nil && req.Page != nil {
|
||||
if req.Page.PageNum > 0 {
|
||||
pageNum = int(req.Page.PageNum)
|
||||
}
|
||||
if req.Page.PageSize > 0 {
|
||||
pageSize = int(req.Page.PageSize)
|
||||
}
|
||||
}
|
||||
modelName := ""
|
||||
taskID := ""
|
||||
var state *int
|
||||
if req != nil {
|
||||
modelName = req.ModelName
|
||||
taskID = req.TaskID
|
||||
state = req.State
|
||||
}
|
||||
list, total, err := dao.Task.List(ctx, pageNum, pageSize, modelName, taskID, state)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.ListTaskRes{List: list, Total: total}, nil
|
||||
}
|
||||
38
service/tmp_store.go
Normal file
38
service/tmp_store.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// saveTmpResult 将模型输出写入临时文件,用于 OSS 上传失败后的“仅重试 OSS”。
|
||||
func saveTmpResult(taskID string, data []byte, ext string) (string, error) {
|
||||
dir := filepath.Join(os.TempDir(), "model-asynch")
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if ext == "" {
|
||||
ext = ".bin"
|
||||
}
|
||||
if ext[0] != '.' {
|
||||
ext = "." + ext
|
||||
}
|
||||
path := filepath.Join(dir, fmt.Sprintf("%s%s", taskID, ext))
|
||||
if err := os.WriteFile(path, data, 0o644); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
|
||||
func loadTmpResult(path string) ([]byte, error) {
|
||||
return os.ReadFile(path)
|
||||
}
|
||||
|
||||
func deleteTmpResult(path string) {
|
||||
if path == "" {
|
||||
return
|
||||
}
|
||||
_ = os.Remove(path)
|
||||
}
|
||||
|
||||
176
service/worker.go
Normal file
176
service/worker.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"model-asynch/dao"
|
||||
"model-asynch/model/entity"
|
||||
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/os/grpool"
|
||||
)
|
||||
|
||||
var AsyncWorker = &asyncWorker{}
|
||||
|
||||
type asyncWorker struct {
|
||||
}
|
||||
|
||||
// RunOnce 由上层定时任务触发:一次性抢占并处理一批任务
|
||||
// - batchSize: 本次抢占数量
|
||||
// - goroutines: 本次并发数(协程池大小)
|
||||
func (w *asyncWorker) RunOnce(ctx context.Context, batchSize, goroutines int) (claimed int, err error) {
|
||||
if batchSize <= 0 {
|
||||
batchSize = 10
|
||||
}
|
||||
if goroutines <= 0 {
|
||||
goroutines = 1
|
||||
}
|
||||
tasks, err := dao.Task.ClaimPendingGlobal(ctx, batchSize)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(tasks) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
pool := grpool.New(goroutines)
|
||||
defer pool.Close()
|
||||
|
||||
claimed = len(tasks)
|
||||
done := make(chan struct{}, claimed)
|
||||
for _, t := range tasks {
|
||||
task := t
|
||||
_ = pool.AddWithRecover(ctx, func(ctx context.Context) {
|
||||
w.handleOne(ctx, task)
|
||||
done <- struct{}{}
|
||||
}, func(ctx context.Context, e error) {
|
||||
if e != nil {
|
||||
_ = dao.Task.UpdateFailedGlobal(ctx, task.Id, fmt.Sprintf("worker panic: %v", e))
|
||||
ReleaseQueueSlot(ctx, task.ModelName, task.TaskID)
|
||||
}
|
||||
done <- struct{}{}
|
||||
})
|
||||
}
|
||||
for i := 0; i < claimed; i++ {
|
||||
<-done
|
||||
}
|
||||
return claimed, nil
|
||||
}
|
||||
|
||||
func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask) {
|
||||
// 从任务入库的 request_payload 里恢复 payload + headers,给 OSS 上传透传鉴权用
|
||||
payload, headers := parseStoredPayload(t.RequestPayload)
|
||||
if len(headers) > 0 {
|
||||
ctx = setTaskHeadersToCtx(ctx, headers)
|
||||
}
|
||||
|
||||
// 1) 拉取模型配置
|
||||
m, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
|
||||
if err != nil {
|
||||
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
|
||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
return
|
||||
}
|
||||
if m == nil || m.Enabled != 1 {
|
||||
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, "模型不存在或未启用")
|
||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
return
|
||||
}
|
||||
|
||||
// 2) 分布式并发限制(按 model_name 全局维度)
|
||||
semKey := fmt.Sprintf("asynch:sem:%s", t.ModelName)
|
||||
leaseSeconds := int64(3600) // 兜底1小时
|
||||
maxC := GetRuntimeMaxConcurrency(ctx, t.ModelName, m.MaxConcurrency)
|
||||
acquired, err := acquireSemaphore(ctx, semKey, maxC, leaseSeconds)
|
||||
if err != nil {
|
||||
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
|
||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
return
|
||||
}
|
||||
if !acquired {
|
||||
// 并发满了:放回排队(重新置回 state=0),下一轮再抢占
|
||||
_ = w.rollbackToPending(ctx, t.Id)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = releaseSemaphore(ctx, semKey)
|
||||
}()
|
||||
|
||||
// 3) 调用模型服务
|
||||
if payload == nil {
|
||||
payload = map[string]any{
|
||||
"taskId": t.TaskID,
|
||||
"inputRef": t.InputRef,
|
||||
}
|
||||
}
|
||||
var (
|
||||
data []byte
|
||||
contentType string
|
||||
ext string
|
||||
)
|
||||
|
||||
// phase=1 表示模型已成功但 OSS 上传失败:优先从临时文件加载,避免重复跑模型
|
||||
if t.Phase == 1 && strings.TrimSpace(t.TmpFile) != "" {
|
||||
data, err = loadTmpResult(t.TmpFile)
|
||||
if err == nil && len(data) > 0 {
|
||||
contentType, ext = DetectFileType(data)
|
||||
} else {
|
||||
// 临时文件不可用:回退重新调用模型
|
||||
data = nil
|
||||
}
|
||||
}
|
||||
if data == nil {
|
||||
// 统计:仅在真正请求模型时 +1(OSS 重试不计入)
|
||||
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(t.TenantId), t.Creator, t.ModelName)
|
||||
|
||||
data, err = InvokeModel(ctx, m, payload, t.ModelKey)
|
||||
if err != nil {
|
||||
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
|
||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
return
|
||||
}
|
||||
contentType, ext = DetectFileType(data)
|
||||
// 将模型输出写入临时文件,后续若 OSS 失败可只重试 OSS
|
||||
tmpPath, err := saveTmpResult(t.TaskID, data, ext)
|
||||
if err == nil && tmpPath != "" {
|
||||
t.TmpFile = tmpPath
|
||||
t.Phase = 1
|
||||
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, t.Id, tmpPath)
|
||||
}
|
||||
}
|
||||
|
||||
// 4) 存储 OSS
|
||||
ossURL, err := Storage.UploadByTask(ctx, t, data, ext, contentType)
|
||||
if err != nil {
|
||||
// OSS 阶段失败:保留临时文件,下一轮仅重试 OSS
|
||||
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, t.Id, err.Error())
|
||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
return
|
||||
}
|
||||
|
||||
// 5) 更新任务状态成功
|
||||
// 注意:expire_at 的计算改为“已下载(state=4)后开始计时”,因此成功(state=2)不写 expire_at。
|
||||
fileType := strings.TrimPrefix(ext, ".")
|
||||
if fileType == "" {
|
||||
fileType = contentType
|
||||
}
|
||||
if err := dao.Task.UpdateSuccessGlobal(ctx, t.Id, ossURL, fileType, int64(len(data)), nil); err != nil {
|
||||
g.Log().Errorf(ctx, "[worker] update success failed: %v", err)
|
||||
return
|
||||
}
|
||||
// 成功/失败均不再占用 queue_limit(state=0/1 才占用)
|
||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
// 6) 成功回调(不影响主流程)
|
||||
t.State = 2
|
||||
t.OssFile = ossURL
|
||||
t.FileType = fileType
|
||||
go triggerSuccessCallback(context.WithoutCancel(ctx), t)
|
||||
// 成功后清理临时文件
|
||||
deleteTmpResult(t.TmpFile)
|
||||
}
|
||||
|
||||
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
|
||||
return dao.Task.RollbackToPendingGlobal(ctx, id)
|
||||
}
|
||||
Reference in New Issue
Block a user