Compare commits

34 Commits
master ... dev

Author SHA1 Message Date
196d2069ac Merge remote-tracking branch 'origin/dev' into dev 2026-06-10 16:47:32 +08:00
7596cbde09 feat(task): 添加任务更新功能 2026-06-10 16:47:21 +08:00
7ec18926e3 ci/cd调整 2026-06-10 16:24:29 +08:00
a6b32bfeb3 ci/cd调整 2026-06-10 16:16:05 +08:00
2dc88ae587 refactor(prompts-core): 重构代码结构和优化工具函数 2026-06-10 14:51:24 +08:00
e906248b0a feat(session): 重构会话管理和Redis缓存机制 2026-06-09 14:00:00 +08:00
e5781aca06 Merge remote-tracking branch 'origin/dev' into dev
# Conflicts:
#	go.sum
2026-06-08 18:02:26 +08:00
0cf8948cd2 feat: 重构异步模型字段并更新依赖 2026-06-08 18:01:53 +08:00
96e8bdfe62 ci/cd调整 2026-06-08 15:37:11 +08:00
26de41d04e ci/cd调整 2026-06-08 13:44:54 +08:00
0bee3685fb ci/cd调整 2026-06-08 13:39:20 +08:00
9049e0d2e8 refactor(prompt): 重构提示词构建服务和回调处理 2026-06-05 11:00:04 +08:00
aae46a4f29 refactor(model-gateway): 重构代码结构并优化数据库查询 2026-06-03 18:37:17 +08:00
bcfcc7ed47 refactor(util): 重构映射工具函数并优化异步任务轮询逻辑 2026-06-03 13:30:39 +08:00
2c7838807b chore(deps): 初始化项目依赖配置 2026-06-02 20:28:06 +08:00
52124385a1 refactor(asynch): 重构异步模型配置和队列管理 2026-06-02 20:26:45 +08:00
c7e9eb889b feat(model): 添加流式配置支持并优化响应处理 2026-05-30 22:08:46 +08:00
qhd
558fd49ec1 fix: 修复模型查询条件为空时的异常行为 2026-05-29 18:06:50 +08:00
d409b84b58 refactor(service): 重构服务模块结构并优化模型配置 2026-05-29 17:54:19 +08:00
e487b4bb5e refactor(task): 重构异步任务处理流程 2026-05-27 09:36:25 +08:00
a28fcbaee9 feat: 新增模型扩展映射与查询配置字段 2026-05-23 18:08:08 +08:00
5416e7a983 Merge remote-tracking branch 'origin/dev' into dev 2026-05-22 13:03:29 +08:00
0e2ac286e9 feat(model): 添加运营商列表功能 2026-05-22 13:03:10 +08:00
qhd
a88dc84d99 fix: 获取聊天模型时过滤当前用户 2026-05-22 13:02:13 +08:00
qhd
4d2d4fd93d Merge branch 'dev' of http://116.204.74.41:3000/red-future/model-gateway into dev 2026-05-22 11:22:50 +08:00
qhd
7129bd2de7 fix: 会话模型查询增加用户隔离 2026-05-22 11:22:35 +08:00
09474eb997 feat(consts): 添加视频模型类型常量定义 2026-05-22 11:17:31 +08:00
4946220185 feat(prompt): 重构提示词服务并添加模型类型子分类 2026-05-22 09:49:46 +08:00
b6cdb8ff1d refactor(prompt): 优化任务等待机制并改进数据结构 2026-05-21 14:23:34 +08:00
4626d819b5 fix(gateway): 修复文件上传功能中的数据写入和头部设置问题 2026-05-21 11:18:39 +08:00
170568e03e refactor(model): 重构模型实体和数据访问层 2026-05-21 10:41:37 +08:00
a080a5536d Merge remote-tracking branch 'origin/dev' into dev 2026-05-18 19:20:18 +08:00
142fea1e91 refactor(service): 重构服务代码结构并更新配置 2026-05-18 19:19:16 +08:00
qhd
a585233c4d fix: 查询私有模型时增加enabled过滤条件 2026-05-15 16:04:18 +08:00
54 changed files with 2940 additions and 2660 deletions

View File

@@ -1,43 +1,23 @@
# 阶段构建 - 第一阶段:编译(使用已安装的镜像)
FROM golang:1.26-alpine3.23 AS builder
# 阶段1: 构建
FROM golang:alpine AS builder
RUN apk add --no-cache git ca-certificates tzdata
ENV TZ=Asia/Shanghai
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
ENV GO111MODULE=on
ENV GOPROXY=https://goproxy.cn,direct
ENV CGO_ENABLED=0
ENV GOTOOLCHAIN=auto
ENV GOPRIVATE=gitea.com/red-future/common
# 配置git使用私有Gitea仓库带Token认证
RUN git config --global url."http://x-token-auth:619679cd366aefea3a50f0622d842a41f2209e08595767bba49c3836ef57d415@116.204.74.41:3000/red-future/common.git".insteadOf "https://gitea.com/red-future/common.git" && \
git config --global credential.helper store
WORKDIR /build
# 复制父目录的 common 模块(因为 go.mod 中使用了本地 replace)
#COPY ../common /build/common
COPY . .
RUN go mod download && go mod tidy
RUN go build -ldflags="-s -w" -o main ./main.go
# 第二阶段:运行
FROM alpine:3.23
ENV TIME_ZONE=Asia/Shanghai
RUN apk add --no-cache ca-certificates tzdata && \
ln -sf /usr/share/zoneinfo/$TIME_ZONE /etc/localtime
WORKDIR /app
# 复制编译好的二进制文件
COPY --from=builder /build/main .
COPY --from=builder /build/config.yml ./
# 创建日志目录
RUN mkdir -p /logs /app/resource/log/run /app/resource/log/server
EXPOSE 3004

10
common/util/convert.go Normal file
View File

@@ -0,0 +1,10 @@
package util
import "github.com/gogf/gf/v2/util/gconv"
// ConvertTo 转换为指定类型
func ConvertTo[T any](v interface{}) *T {
var t T
_ = gconv.Struct(v, &t)
return &t
}

115
common/util/files.go Normal file
View File

@@ -0,0 +1,115 @@
package util
import (
"encoding/json"
"fmt"
"net/http"
"os"
"path/filepath"
"strings"
)
// DetectFileType 根据返回的二进制内容推断 contentType + 扩展名(尽量稳定)
func DetectFileType(data []byte) (contentType string, ext string) {
if len(data) == 0 {
return "application/octet-stream", ""
}
ct := http.DetectContentType(data)
// gateway.DetectContentType 可能带 charset 等参数text/plain; charset=utf-8
if idx := strings.Index(ct, ";"); idx > 0 {
ct = strings.TrimSpace(ct[:idx])
}
switch ct {
case "audio/mpeg":
return ct, ".mp3"
case "audio/wave", "audio/wav", "audio/x-wav":
return ct, ".wav"
case "video/mp4":
return ct, ".mp4"
case "image/png":
return ct, ".png"
case "image/jpeg":
return ct, ".jpg"
case "application/pdf":
return ct, ".pdf"
case "text/plain":
return ct, ".txt"
case "application/json":
return ct, ".json"
default:
// 兜底:尝试从 ct 截取 subtype 作为后缀(例如 application/json
if parts := strings.Split(ct, "/"); len(parts) == 2 {
sub := parts[1]
// 避免出现 "plain; charset=utf-8" 之类的后缀
if idx := strings.Index(sub, ";"); idx > 0 {
sub = strings.TrimSpace(sub[:idx])
}
return ct, "." + sub
}
return ct, ""
}
}
// SaveTmpResult 将模型输出写入临时文件,用于 OSS 上传失败后的“仅重试 OSS”。
func SaveTmpResult(taskID string, data []byte, ext string) (string, error) {
dir := filepath.Join(os.TempDir(), "model-asynch")
if err := os.MkdirAll(dir, 0o755); err != nil {
return "", err
}
if ext == "" {
ext = ".bin"
}
if ext[0] != '.' {
ext = "." + ext
}
path := filepath.Join(dir, fmt.Sprintf("%s%s", taskID, ext))
if err := os.WriteFile(path, data, 0o644); err != nil {
return "", err
}
return path, nil
}
// SaveTempFileByType
// 根据传入的数据自动判断:
// 若是 []byte 且后缀为 .mp3 → 保存二进制音频
// 若是任意结构体/map → 自动转 JSON 保存
// 返回:新临时文件路径、错误
func SaveTempFileByType(taskID string, data any, oldTmpFile string) (string, error) {
// 1. 先清理旧临时文件(统一逻辑)
if oldTmpFile != "" {
_ = os.Remove(oldTmpFile)
}
var tmpPath string
var tmpErr error
// 2. 判断是否是二进制音频([]byte + .mp3
if audioData, ok := data.([]byte); ok {
tmpPath, tmpErr = saveTmpResult(taskID, audioData, ".mp3")
} else {
// 3. 其他类型 → 序列化为 JSON 保存
mappedBytes, err := json.Marshal(data)
if err != nil {
return "", err
}
if len(mappedBytes) == 0 {
return "", nil
}
tmpPath, tmpErr = saveTmpResult(taskID, mappedBytes, ".json")
}
if tmpErr != nil || tmpPath == "" {
return "", tmpErr
}
return tmpPath, nil
}
// saveTmpResult 你原有的底层保存文件方法(保留不动)
func saveTmpResult(taskID string, data []byte, ext string) (string, error) {
// 你原来实现,比如:
filename := taskID + ext
tmpPath := filepath.Join(os.TempDir(), filename)
err := os.WriteFile(tmpPath, data, 0644)
return tmpPath, err
}

79
common/util/headers.go Normal file
View File

@@ -0,0 +1,79 @@
package util
import (
"context"
"gitea.redpowerfuture.com/red-future/common/utils"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// AsyncCtx 固化异步上下文中的 token 和用户信息,避免请求结束后丢失
func AsyncCtx(ctx context.Context) context.Context {
asyncCtx := context.WithoutCancel(ctx)
if r := g.RequestFromCtx(ctx); r != nil {
if token := r.Header.Get("Authorization"); token != "" {
asyncCtx = context.WithValue(asyncCtx, "token", token)
}
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
asyncCtx = context.WithValue(asyncCtx, "xUserInfo", userInfo)
}
}
if user, err := utils.GetUserInfo(ctx); err == nil && user != nil {
asyncCtx = context.WithValue(asyncCtx, "user", user)
}
return asyncCtx
}
// ForwardHeaders 透传调用链路的头信息,优先使用 ctx 中的固化值
func ForwardHeaders(ctx context.Context) map[string]string {
headers := make(map[string]string)
SetHeaderFromContext(headers, ctx, "Authorization", "token")
SetHeaderFromContext(headers, ctx, "X-User-Info", "xUserInfo")
FallbackToRequestHeaders(headers, ctx)
return headers
}
// SetHeaderFromContext 从上下文中设置 header
func SetHeaderFromContext(headers map[string]string, ctx context.Context, headerKey, ctxKey string) {
if value, ok := ctx.Value(ctxKey).(string); ok && value != "" {
headers[headerKey] = value
}
}
// FallbackToRequestHeaders 从请求头中获取作为兜底
func FallbackToRequestHeaders(headers map[string]string, ctx context.Context) {
r := g.RequestFromCtx(ctx)
if r == nil {
return
}
if headers["Authorization"] == "" {
if token := r.Header.Get("Authorization"); token != "" {
headers["Authorization"] = token
}
}
if headers["X-User-Info"] == "" {
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
headers["X-User-Info"] = userInfo
}
}
}
// SetTaskHeadersToCtx 把任务入库时保存的 header 信息注入 ctx给 worker 调 OSS 用
func SetTaskHeadersToCtx(ctx context.Context, headers map[string]string) context.Context {
if headers == nil {
return ctx
}
if v := gconv.String(headers["Authorization"]); v != "" {
ctx = context.WithValue(ctx, "token", v)
}
if v := gconv.String(headers["X-User-Info"]); v != "" {
ctx = context.WithValue(ctx, "xUserInfo", v)
}
return ctx
}

359
common/util/mapping.go Normal file
View File

@@ -0,0 +1,359 @@
package util
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"model-gateway/model/entity"
"net/http"
"net/url"
"regexp"
"strings"
"time"
"gitea.redpowerfuture.com/red-future/common/utils"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
tgjson "github.com/tidwall/gjson"
)
// ParseAndValidate 解析并校验结果
func ParseAndValidate(raw map[string]any, model *entity.AsynchModel) (map[string]any, error) {
// 1) 解析 content 字符串为 rounds 数组
contentVal, ok := raw[model.ResponseBody]
if !ok {
return raw, fmt.Errorf("字段 %s 不存在", model.ResponseBody)
}
contentStr, ok := contentVal.(string)
if !ok || strings.TrimSpace(contentStr) == "" {
return raw, fmt.Errorf("字段 %s 为空或不是字符串", model.ResponseBody)
}
var arr []any
if err := json.Unmarshal([]byte(contentStr), &arr); err != nil {
return raw, fmt.Errorf("JSON解析失败: %w", err)
}
if len(arr) == 0 {
return raw, fmt.Errorf("解析后数组为空")
}
// 2) 校验必填字段
if len(model.RequiredFields) > 0 {
for i, r := range arr {
round, ok := r.(map[string]any)
if !ok {
continue
}
for _, field := range model.RequiredFields {
if gjson.New(round).Get(field).IsNil() {
return raw, fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, field)
}
}
}
}
return map[string]any{"total_rounds": len(arr), "rounds": arr}, nil
}
// ParseStructResult 解析结构结果
func ParseStructResult(raw map[string]any, responseBody string) map[string]any {
contentVal := raw[responseBody]
// 是字符串,尝试解析
contentStr := gconv.String(contentVal)
if contentStr == "" || contentStr == "0" {
return map[string]any{
"total_rounds": 1,
"rounds": []map[string]any{{responseBody: raw}},
}
}
// 尝试解析为数组
var arr []any
if err := json.Unmarshal([]byte(contentStr), &arr); err == nil && len(arr) > 0 {
return map[string]any{
"total_rounds": 1,
"rounds": []map[string]any{{responseBody: arr}},
}
}
// 尝试解析为单个对象
var parsed any
if err := json.Unmarshal([]byte(contentStr), &parsed); err == nil {
return map[string]any{
"total_rounds": 1,
"rounds": []map[string]any{{responseBody: parsed}},
}
}
// 兜底:原始字符串作为内容
return map[string]any{
"total_rounds": 1,
"rounds": []map[string]any{{responseBody: contentStr}},
}
}
// ValidatePromptResult 校验模型返回结果的 JSON 结构完整性
// raw 必须包含 "rounds" 字段,格式为 []map[string]any
func ValidatePromptResult(raw map[string]any, model *entity.AsynchModel) error {
// 1) 获取 rounds
roundsRaw, ok := raw["rounds"]
if !ok {
return fmt.Errorf("缺少 rounds 字段")
}
rounds, ok := roundsRaw.([]any)
if !ok {
return fmt.Errorf("rounds 不是数组")
}
if len(rounds) == 0 {
return fmt.Errorf("rounds 数组为空")
}
// 2) 没有配置必填字段,跳过
if len(model.RequiredFields) == 0 {
return nil
}
// 3) 逐条校验
for i, r := range rounds {
round, ok := r.(map[string]any)
if !ok {
continue
}
for _, field := range model.RequiredFields {
if gjson.New(round).Get(field).IsNil() {
return fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, field)
}
}
}
return nil
}
// validateRequiredFields 校验单个 round 对象的必选字段
func validateRequiredFields(round map[string]any, requiredFields []string, prefix string) error {
for _, field := range requiredFields {
if gjson.New(round).Get(field).IsNil() {
return fmt.Errorf("%s 缺少必填字段: %s", prefix, field)
}
}
return nil
}
// ParseHeadMsgHeaders 从 head_msg JSON 中提取请求头
// head_msg 格式示例:
//
// {
// "Authorization": "Bearer xxx",
// "Content-Type": "application/json",
// "X-Api-App-Id": "5147401364",
// "X-Api-Access-Key": "VCqRX7..."
// }
func ParseHeadMsgHeaders(headMsg map[string]any) map[string]string {
if len(headMsg) == 0 {
return nil
}
out := make(map[string]string, len(headMsg))
for k, v := range headMsg {
out[k] = gconv.String(v)
}
return out
}
// MapResponsePayload 映射模型响应为标准格式
func MapResponsePayload(mapping map[string]any, result map[string]any) (map[string]any, error) {
if len(mapping) == 0 {
return result, nil
}
// 把 result 转成 JSON 字符串tidwall/gjson 需要字符串输入
resultBytes, _ := json.Marshal(result)
resultStr := string(resultBytes)
mapped := make(map[string]any)
for standardField, modelPath := range mapping {
path := gconv.String(modelPath)
if path == "" {
continue
}
value := tgjson.Get(resultStr, path)
if !value.Exists() {
continue
}
// 如果是数组路径(含 #),取 Array否则取单值
if strings.Contains(path, "#") {
var arr []any
for _, v := range value.Array() {
arr = append(arr, v.Value())
}
mapped[standardField] = arr
} else {
mapped[standardField] = value.Value()
}
}
return mapped, nil
}
// GetModelBody 获取数据库中保存的模型信息
func GetModelBody(v map[string]any) map[string]any {
if v == nil {
return nil
}
if p, ok := v["body"]; ok {
return gconv.Map(p)
}
return v
}
// BodyToQuery 将 body 转为 url.Values
func BodyToQuery(payload map[string]any) (url.Values, error) {
q := url.Values{}
for k, v := range payload {
if v == nil {
continue
}
q.Set(k, gconv.String(v))
}
return q, nil
}
// PullTaskResult 轮询查询异步任务结果直到完成
func PullTaskResult(ctx context.Context, body map[string]any, queryConfig map[string]any, headMsg map[string]any) (map[string]any, error) {
// 1) 解析配置
// 1.1 提取 taskID
taskIDPath := gconv.String(queryConfig["task_id"])
taskID := gconv.String(gjson.New(body).Get(taskIDPath).Val())
if taskID == "" {
return nil, fmt.Errorf("无法从路径 %s 提取 taskID", taskIDPath)
}
g.Log().Infof(ctx, "[PullTaskResult] taskID=%s", taskID)
// 1.2 请求地址,替换 {id}
queryUrl := gconv.String(queryConfig["url"])
queryUrl = replaceURLParams(queryUrl, map[string]any{"id": taskID})
// 1.3 请求方式
method := gconv.String(queryConfig["method"])
if method == "" {
method = "GET"
}
// 1.4 状态判断配置
statusPath := gconv.String(queryConfig["status_path"])
statusValues, _ := queryConfig["status_values"].(map[string]any)
if statusPath == "" {
statusPath = "status"
}
// 1.5 轮询间隔
interval := gconv.Int(queryConfig["interval_seconds"])
if interval <= 0 {
interval = 2
}
// 1.6 请求体
reqBodyMap := map[string]any{"task_id": taskID}
// 2) 轮询请求
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
var reqBody io.Reader
if method == "POST" {
bs, _ := json.Marshal(reqBodyMap)
reqBody = bytes.NewReader(bs)
}
req, err := http.NewRequestWithContext(ctx, method, queryUrl, reqBody)
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
// 统一用 headMsg 注入请求头
for hk, hv := range ParseHeadMsgHeaders(headMsg) {
req.Header.Set(hk, hv)
}
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
g.Log().Warningf(ctx, "[PullTaskResult] 请求失败 taskID=%s err=%v", taskID, err)
time.Sleep(time.Duration(interval) * time.Second)
continue
}
raw, _ := io.ReadAll(resp.Body)
_ = resp.Body.Close()
g.Log().Infof(ctx, "[PullTaskResult] taskID=%s statusCode=%d body=%s", taskID, resp.StatusCode, string(raw))
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
time.Sleep(time.Duration(interval) * time.Second)
continue
}
var result map[string]any
_ = json.Unmarshal(raw, &result)
statusVal := gjson.New(result).Get(statusPath).Val()
statusStr := gconv.String(statusVal)
g.Log().Infof(ctx, "[PullTaskResult] 状态 taskID=%s status=%v", taskID, statusVal)
if matchStatus(statusStr, statusValues["succeeded"]) {
g.Log().Infof(ctx, "[PullTaskResult] 任务成功 taskID=%s", taskID)
return result, nil
}
if matchStatus(statusStr, statusValues["failed"]) {
g.Log().Errorf(ctx, "[PullTaskResult] 任务失败 taskID=%s", taskID)
return result, fmt.Errorf("任务失败")
}
time.Sleep(time.Duration(interval) * time.Second)
}
}
func matchStatus(actual string, expected any) bool {
expectedStr := gconv.String(expected)
if actual == expectedStr {
return true
}
switch v := expected.(type) {
case []any:
for _, item := range v {
if actual == gconv.String(item) {
return true
}
}
}
return false
}
// replaceURLParams 替换 URL 中的 {key}
func replaceURLParams(url string, params map[string]any) string {
re := regexp.MustCompile(`\{([^}]+)}`)
return re.ReplaceAllStringFunc(url, func(s string) string {
key := strings.Trim(s, "{}")
if val, ok := params[key]; ok {
return gconv.String(val)
}
return s
})
}
// InjectCallbackURL 将回调地址注入到请求体中
func InjectCallbackURL(ctx context.Context, payload map[string]any, callbackURL string) map[string]any {
if callbackURL == "" {
return payload
}
payload[callbackURL] = utils.GetCallbackURL(ctx, "/task/modelCallback")
return payload
}

150
common/util/streaming.go Normal file
View File

@@ -0,0 +1,150 @@
package util
import (
"encoding/base64"
"encoding/json"
"fmt"
"sort"
"strings"
"github.com/gogf/gf/v2/encoding/gjson"
)
// ================================================================
// ParseStreamResponse 流式响应解析(通用入口)
func ParseStreamResponse(rawBytes []byte, streamConfig map[string]any) (map[string]any, error) {
enabled, _ := streamConfig["enabled"].(bool)
if !enabled {
return gjson.New(string(rawBytes)).Map(), nil
}
parser, _ := streamConfig["parser"].(string)
if parser == "base64_concat" {
return parseBase64Stream(rawBytes)
}
return parseSSEStream(rawBytes, streamConfig)
}
// parseBase64Stream 拼接流式 base64 并解码为二进制TTS 等音频模型)
func parseBase64Stream(rawBytes []byte) (map[string]any, error) {
lines := strings.Split(string(rawBytes), "\n")
var audioBase64 strings.Builder
for _, line := range lines {
line = strings.TrimSpace(line)
if line == "" {
continue
}
var chunk map[string]any
if err := json.Unmarshal([]byte(line), &chunk); err != nil {
continue
}
if data, ok := chunk["data"].(string); ok && data != "" {
audioBase64.WriteString(data)
}
}
cleanBase64 := strings.Map(func(r rune) rune {
if r == ' ' || r == '\n' || r == '\r' || r == '\t' {
return -1
}
return r
}, audioBase64.String())
audioBytes, err := base64.StdEncoding.DecodeString(cleanBase64)
if err != nil {
audioBytes, err = base64.RawStdEncoding.DecodeString(cleanBase64)
if err != nil {
return nil, fmt.Errorf("base64 解码失败: %w", err)
}
}
return map[string]any{"audio": audioBytes}, nil
}
// parseSSEStream SSE 流式解析(图片模型等)
func parseSSEStream(rawBytes []byte, streamConfig map[string]any) (map[string]any, error) {
events, _ := streamConfig["events"].([]any)
if len(events) == 0 {
return gjson.New(string(rawBytes)).Map(), nil
}
lines := strings.Split(string(rawBytes), "\n")
result := make(map[string]any)
var partials []map[string]any
for _, line := range lines {
line = strings.TrimSpace(line)
if line == "" || line == "[DONE]" {
continue
}
if strings.HasPrefix(line, "event:") {
continue
}
if strings.HasPrefix(line, "data:") {
line = strings.TrimPrefix(line, "data:")
line = strings.TrimSpace(line)
}
var chunk map[string]any
if err := json.Unmarshal([]byte(line), &chunk); err != nil {
continue
}
chunkType, _ := chunk["type"].(string)
for _, evt := range events {
e, _ := evt.(map[string]any)
match, _ := e["match"].(string)
if !strings.Contains(chunkType, match) {
continue
}
fields, _ := e["fields"].(map[string]any)
aggregateTo, _ := e["aggregate_to"].(string)
evtType, _ := e["type"].(string)
switch evtType {
case "partial":
item := make(map[string]any)
for localKey, chunkKey := range fields {
item[localKey] = chunk[chunkKey.(string)]
}
partials = append(partials, item)
case "final":
for localKey, chunkKey := range fields {
val := gjson.New(chunk).Get(chunkKey.(string))
if !val.IsNil() {
if _, exists := result[aggregateTo]; !exists {
result[aggregateTo] = make(map[string]any)
}
result[aggregateTo].(map[string]any)[localKey] = val.Val()
}
}
}
}
}
if len(partials) > 0 {
for _, evt := range events {
e, _ := evt.(map[string]any)
if e["type"] == "partial" {
if orderBy, ok := e["order_by"].(string); ok {
sort.Slice(partials, func(i, j int) bool {
return fmt.Sprint(partials[i][orderBy]) < fmt.Sprint(partials[j][orderBy])
})
}
result[e["aggregate_to"].(string)] = partials
break
}
}
}
mergedBytes, _ := json.Marshal(result)
return gjson.New(mergedBytes).Map(), nil
}

View File

@@ -26,20 +26,45 @@ database:
updatedAt: "updated_at" # (可选)自动更新时间字段名称
deletedAt: "deleted_at" # (可选)软删除时间字段名称
timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性为true时CreatedAt/UpdatedAt/DeletedAt都将失效
model_gateway:
- type: "pgsql"
host: "116.204.74.41"
port: "15432"
user: "postgres"
pass: "Bjang09@686^*^"
name: "model-gateway"
prefix: ""
role: "master"
debug: true
dryRun: false
charset: "utf8"
timezone: "Asia/Shanghai"
maxIdle: 5
maxOpen: 20
maxLifetime: "30s"
maxIdleConnTime: "30s"
createdAt: "created_at"
updatedAt: "updated_at"
deletedAt: "deleted_at"
timeMaintainDisabled: false
redis:
default:
address: 116.204.74.41:6379
address: 192.168.3.30:6379
db: 0
consul:
address: 116.204.74.41:8500
address: 192.168.3.30:8500
jaeger:
addr: 116.204.74.41:4318
addr: 192.168.3.30:4318
# 本地调试用:可选自动执行 worker/cleaner默认关闭
asynch:
queryPending:
enabled: false
intervalSeconds: 10 # 每10秒轮询一次
limit: 10 # 每次查10条
worker:
enabled: false
intervalSeconds: 5
@@ -48,11 +73,3 @@ asynch:
cleaner:
enabled: false
intervalSeconds: 30
modelType:
types:
1: "推理模型"
2: "图片模型"
3: "音频模型"
4: "向量化模型"
5: "全模态模型"

115
consts/public/public.go Normal file
View File

@@ -0,0 +1,115 @@
package public
const (
CallModeSync = 0 // 同步调用
CallModeAsync = 1 // 异步调用
CallModeStream = 2 // 流式调用
)
const (
BuildTypePrompt = 1 //提示词构建
BuildTypeNode = 2 //节点构建
BuildTypeStruct = 3 //结构构建
)
// ModelType 模型类型常量
const (
ModelTypeInference = 100 // 推理模型
ModelTypeImage = 200 // 图片模型
ImageSubTypeTextToImage = 201 // 图片模型-文生图
ImageSubTypeImageToImage = 202 // 图片模型-图生图
ImageSubTypeImageEdit = 203 // 图片模型-图片编辑
ImageSubTypeImageVariation = 204 // 图片模型-图片变体
ImageSubTypeImageTextToImage = 205 // 图片模型-图文生图
ModelTypeAudio = 300 // 音频模型
AudioSubTypeTextToSpeech = 301 // 音频模型-文生音
AudioSubTypeSpeechToText = 302 // 音频模型-音生文
AudioSubTypeSpeechToSpeech = 303 // 音频模型-音生音
ModelTypeVector = 400 // 向量化模型
VectorSubTypeEmbedding = 401 // 向量化模型-文本嵌入
VectorSubTypeRerank = 402 // 向量化模型-重排序
ModelTypeOmni = 500 // 全模态模型
OmniSubTypeTextImageAudio = 501 // 全模态模型-文图音
OmniSubTypeVision = 502 // 全模态模型-视觉理解
ModelTypeVideo = 600 // 视频模型
VideoSubTypeTextToVideo = 601 // 视频模型-文生视频
VideoSubTypeImageToVideo = 602 // 视频模型-图生视频
VideoSubTypeImageTextToVideo = 603 // 视频模型-图文生视频
VideoSubTypeVideoToVideo = 604 // 视频模型-视频生视频
)
// ModelTypeName 模型类型名称映射
var ModelTypeName = map[int]string{
ModelTypeInference: "推理模型",
ModelTypeImage: "图片模型",
ImageSubTypeTextToImage: "图片模型-文生图",
ImageSubTypeImageToImage: "图片模型-图生图",
ImageSubTypeImageEdit: "图片模型-图片编辑",
ImageSubTypeImageVariation: "图片模型-图片变体",
ImageSubTypeImageTextToImage: "图片模型-图文生图",
ModelTypeAudio: "音频模型",
AudioSubTypeTextToSpeech: "音频模型-文生音",
AudioSubTypeSpeechToText: "音频模型-音生文",
AudioSubTypeSpeechToSpeech: "音频模型-音生音",
ModelTypeVector: "向量化模型",
VectorSubTypeEmbedding: "向量化模型-文本嵌入",
VectorSubTypeRerank: "向量化模型-重排序",
ModelTypeOmni: "全模态模型",
OmniSubTypeTextImageAudio: "全模态模型-文图音",
OmniSubTypeVision: "全模态模型-视觉理解",
ModelTypeVideo: "视频模型",
VideoSubTypeTextToVideo: "视频模型-文生视频",
VideoSubTypeImageToVideo: "视频模型-图生视频",
VideoSubTypeImageTextToVideo: "视频模型-图文生视频",
VideoSubTypeVideoToVideo: "视频模型-视频生视频",
}
// 运营商常量
const (
OperatorAliyun = "阿里云百炼"
OperatorVolcengine = "火山引擎"
OperatorTencent = "腾讯云"
OperatorHuawei = "华为云"
OperatorBaidu = "百度智能云"
OperatorOpenAI = "OpenAI"
OperatorAzure = "Azure OpenAI"
OperatorAWS = "AWS Bedrock"
OperatorGoogle = "Google Cloud"
OperatorDeepSeek = "DeepSeek"
OperatorMoonshot = "Moonshot"
OperatorZhipu = "智谱AI"
OperatorBaichuan = "百川智能"
OperatorMinimax = "MiniMax"
OperatorXunfei = "科大讯飞"
OperatorOthers = "其他"
)
// OperatorList 运营商列表(供前端下拉框使用)
var OperatorList = []string{
OperatorAliyun,
OperatorVolcengine,
OperatorTencent,
OperatorHuawei,
OperatorBaidu,
OperatorOpenAI,
OperatorAzure,
OperatorAWS,
OperatorGoogle,
OperatorDeepSeek,
OperatorMoonshot,
OperatorZhipu,
OperatorBaichuan,
OperatorMinimax,
OperatorXunfei,
OperatorOthers,
}

View File

@@ -1,5 +1,9 @@
package public
const (
DbNameModelGateway = "model_gateway" //数据库名称
)
const (
TableNameModel = "asynch_models" // 模型表
TableNameTask = "asynch_task" // 任务表

View File

@@ -1 +0,0 @@
package controller

View File

@@ -2,12 +2,9 @@ package controller
import (
"context"
"model-gateway/model/dto"
"model-gateway/model/entity"
"model-gateway/service"
"gitea.com/red-future/common/beans"
modelService "model-gateway/service/model"
"model-gateway/service/queue"
)
type model struct{}
@@ -17,71 +14,53 @@ var Model = new(model)
// CreateModel 添加配置
func (c *model) CreateModel(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
return service.Model.Create(ctx, req)
return modelService.Model.Create(ctx, req)
}
// UpdateModel 更改配置
func (c *model) UpdateModel(ctx context.Context, req *dto.UpdateModelReq) (res *beans.ResponseEmpty, err error) {
err = service.Model.Update(ctx, req)
func (c *model) UpdateModel(ctx context.Context, req *dto.UpdateModelReq) (res *dto.UpdateModelRes, err error) {
err = modelService.Model.Update(ctx, req)
return
}
// DeleteModel 删除配置
func (c *model) DeleteModel(ctx context.Context, req *dto.DeleteModelReq) (res *beans.ResponseEmpty, err error) {
err = service.Model.Delete(ctx, req.ID)
func (c *model) DeleteModel(ctx context.Context, req *dto.DeleteModelReq) (res *dto.DeleteModelRes, err error) {
err = modelService.Model.Delete(ctx, req)
return
}
// GetModel 获取配置详情(按 modelName
// GetModel 获取配置详情
func (c *model) GetModel(ctx context.Context, req *dto.GetModelReq) (res *dto.GetModelRes, err error) {
model, err := service.Model.Get(ctx, req.ID)
if err != nil {
return nil, err
}
if model == nil {
return nil, nil
}
return &dto.GetModelRes{Model: model}, nil
return modelService.Model.Get(ctx, req)
}
// ListModel 配置列表
func (c *model) ListModel(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) {
list, total, err := service.Model.List(ctx, req)
if err != nil {
return nil, err
}
return &dto.ListModelRes{
List: list,
Total: total,
}, nil
return modelService.Model.List(ctx, req)
}
// AutoTune 动态调参(由上层定时任务每小时触发一次)
func (c *model) AutoTune(ctx context.Context, req *dto.AutoTuneReq) (res *dto.AutoTuneRes, err error) {
windowSeconds := 3600
if req != nil && req.WindowSeconds > 0 {
windowSeconds = req.WindowSeconds
}
list, err := service.AutoTune(ctx, windowSeconds)
if err != nil {
return nil, err
}
return &dto.AutoTuneRes{List: list}, nil
return queue.AutoTune(ctx, req)
}
func (c *model) ListType(ctx context.Context, req *dto.ListTypeReq) (res dto.TypeItem, err error) {
modelType := service.GetModelTypesFromConfig(ctx)
res.Type = modelType
return res, nil
// ListType 模型类型列表
func (c *model) ListType(ctx context.Context, req *dto.ListTypeReq) (res *dto.TypeItem, err error) {
return modelService.GetModelTypesFromConfig()
}
// ListOperator 运营商列表
func (c *model) ListOperator(ctx context.Context, req *dto.ListOperatorReq) (res *dto.ListOperatorRes, err error) {
return modelService.GetOperatorList()
}
// UpdateChatModel 更新是否为聊天模型
func (c *model) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) (res *beans.ResponseEmpty, err error) {
err = service.Model.UpdateChatModel(ctx, req)
func (c *model) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) (res *dto.UpdateChatModelRes, err error) {
err = modelService.Model.UpdateChatModel(ctx, req)
return
}
// GetIsChatModel 获取是否为聊天模型
func (c *model) GetIsChatModel(ctx context.Context, req *dto.GetIsChatModelReq) (res *entity.AsynchModel, err error) {
return service.Model.GetIsChatModel(ctx)
// GetIsChatModel 获取当前会话模型
func (c *model) GetIsChatModel(ctx context.Context, req *dto.GetIsChatModelReq) (res *dto.GetIsChatModelRes, err error) {
return modelService.Model.GetIsChatModel(ctx)
}

View File

@@ -2,9 +2,9 @@ package controller
import (
"context"
statService "model-gateway/service/stat"
"model-gateway/model/dto"
"model-gateway/service"
)
type stat struct{}
@@ -14,5 +14,5 @@ var Stat = new(stat)
// ListModelStat 统计列表
func (c *stat) ListModelStat(ctx context.Context, req *dto.ListModelStatReq) (res *dto.ListModelStatRes, err error) {
return service.Stat.List(ctx, req)
return statService.Stat.List(ctx, req)
}

View File

@@ -2,9 +2,10 @@ package controller
import (
"context"
"model-gateway/service/job"
taskService "model-gateway/service/task"
"model-gateway/model/dto"
"model-gateway/service"
)
type task struct{}
@@ -14,44 +15,35 @@ var Task = new(task)
// CreateTask 根据 modelName 创建异步任务,返回 taskId
func (c *task) CreateTask(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
return service.Task.Create(ctx, req)
return taskService.Task.Create(ctx, req)
}
// ModelTaskCallback 接收模型异步任务的回调通知
func (c *task) ModelTaskCallback(ctx context.Context, req *dto.ModelTaskCallbackReq) (res *dto.ModelTaskCallbackRes, err error) {
return taskService.Task.ModelTaskCallback(ctx, req)
}
// QueryPendingTasks 批量轮询进行中的异步任务
func (c *task) QueryPendingTasks(ctx context.Context, req *dto.QueryPendingTasksReq) (res *dto.QueryPendingTasksRes, err error) {
return taskService.Task.QueryPendingTasks(ctx, req)
}
// GetTaskResult 获取任务结果(只返回 oss 地址 + state
func (c *task) GetTaskResult(ctx context.Context, req *dto.GetTaskResultReq) (res *dto.GetTaskResultRes, err error) {
return service.Task.GetResult(ctx, req.TaskID)
return taskService.Task.GetResult(ctx, req.TaskID)
}
// GetTaskBatch 批量查询任务(成功任务标记为已下载)
func (c *task) GetTaskBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
return service.Task.GetBatch(ctx, req)
return taskService.Task.GetBatch(ctx, req)
}
// ListTask 任务列表分页查询
func (c *task) ListTask(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
return service.Task.List(ctx, req)
}
// RunWork 手动触发一次 worker由上层定时任务调用
func (c *task) RunWork(ctx context.Context, req *dto.RunWorkReq) (res *dto.RunWorkRes, err error) {
batchSize, goroutines := 10, 1
if req != nil {
if req.BatchSize > 0 {
batchSize = req.BatchSize
}
if req.Goroutines > 0 {
goroutines = req.Goroutines
}
}
n, err := service.AsyncWorker.RunOnce(ctx, batchSize, goroutines)
if err != nil {
return nil, err
}
return &dto.RunWorkRes{Claimed: n}, nil
return taskService.Task.List(ctx, req)
}
// CleanWork 手动触发一次 cleaner由上层定时任务调用
func (c *task) CleanWork(ctx context.Context, req *dto.CleanWorkReq) (res *dto.CleanWorkRes, err error) {
service.Cleaner.RunOnce(ctx)
return &dto.CleanWorkRes{Ok: true}, nil
return job.Cleaner.RunOnce(ctx)
}

View File

@@ -2,14 +2,12 @@ package dao
import (
"context"
"fmt"
"model-gateway/consts/public"
"model-gateway/model/dto"
"model-gateway/model/entity"
"strconv"
"gitea.com/red-future/common/db/gfdb"
"gitea.com/red-future/common/utils"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
@@ -18,159 +16,148 @@ var Model = &modelDao{}
type modelDao struct{}
func (d *modelDao) Insert(ctx context.Context, req *dto.CreateModelReq) (id int64, err error) {
asyncModel := new(entity.AsynchModel)
err = gconv.Struct(req, &asyncModel)
// Insert 插入
func (d *modelDao) Insert(ctx context.Context, req *entity.AsynchModel) (id int64, err error) {
m := new(entity.AsynchModel)
err = gconv.Struct(req, &m)
if err != nil {
return
}
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).Data(asyncModel).Insert()
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
Insert(m)
if err != nil {
return 0, err
return
}
return r.LastInsertId()
}
func (d *modelDao) Update(ctx context.Context, m *dto.UpdateModelReq) (rows int64, err error) {
// 触发 gfdb 的 updateHook 自动填充 updater需要显式带 updater 字段
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
// Update 更新
func (d *modelDao) Update(ctx context.Context, req *entity.AsynchModel) (rows int64, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
OmitEmpty().
Where(entity.AsynchModelCol.Id, m.ID).
Data(m).
Data(&req).
Where(entity.AsynchModelCol.Id, req.Id).
Update()
if err != nil {
return 0, err
return
}
return r.RowsAffected()
}
func (d *modelDao) DeleteByID(ctx context.Context, id string) (rows int64, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
Where(entity.AsynchModelCol.Id, id).
// Delete 删除
func (d *modelDao) Delete(ctx context.Context, req *entity.AsynchModel) (rows int64, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
OmitEmpty().
Where(entity.AsynchModelCol.Id, req.Id).
Delete()
if err != nil {
return 0, err
return
}
return r.RowsAffected()
}
func (d *modelDao) GetByModelName(ctx context.Context, modelName string) (m *entity.AsynchModel, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
Where(entity.AsynchModelCol.ModelName, modelName).
One()
if err != nil {
return nil, err
}
if r.IsEmpty() {
return nil, nil
}
err = r.Struct(&m)
return
}
func (d *modelDao) Get(ctx context.Context, id int64) (m *entity.AsynchModel, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
NoTenantId(ctx).
Where(entity.AsynchModelCol.Id, id).
One()
if err != nil {
return nil, err
}
if r.IsEmpty() {
return nil, nil
}
err = r.Struct(&m)
return
}
func (d *modelDao) Count(ctx context.Context, req *dto.GetModelReq) (count int, err error) {
count, err = gfdb.DB(ctx).Model(ctx, public.TableNameModel).OmitEmpty().
// Get 获取模型
func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
OmitEmpty().
Where(entity.AsynchModelCol.Id, req.Id).
Where(entity.AsynchModelCol.Creator, req.Creator).
Where(entity.AsynchModelCol.Id, req.ID).Count()
Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel).
Where(entity.AsynchModelCol.ModelName, req.ModelName).
Fields(fields).One()
if err != nil {
return
}
err = r.Struct(&m)
return
}
func (d *modelDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike string, modelType int, isPrivate int) (list []*entity.AsynchModel, total int64, err error) {
model := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
OrderDesc(entity.AsynchModelCol.CreatedAt)
if modelNameLike != "" {
model = model.WhereLike(entity.AsynchModelCol.ModelName, "%"+modelNameLike+"%")
}
if modelType != 0 {
model = model.Where(entity.AsynchModelCol.ModelType, modelType)
}
if isPrivate != 0 {
model = model.Where(entity.AsynchModelCol.IsPrivate, isPrivate)
}
if pageNum > 0 && pageSize > 0 {
model = model.Page(pageNum, pageSize)
}
r, totalInt, err := model.AllAndCount(false)
//// Get 按ID获取带租户隔离只查当前租户
//func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
// var whereCondition strings.Builder
// var queryParams []interface{}
// if !g.IsEmpty(req.Id) {
// whereCondition.WriteString(fmt.Sprintf(" AND %s = (?) ", entity.AsynchModelCol.Id))
// queryParams = append(queryParams, req.Id)
// }
// if !g.IsEmpty(req.Creator) {
// whereCondition.WriteString(fmt.Sprintf(" AND %s = (?) ", entity.AsynchModelCol.Creator))
// queryParams = append(queryParams, req.Creator)
// }
// if !g.IsEmpty(req.IsChatModel) {
// whereCondition.WriteString(fmt.Sprintf(" AND %s = (?) ", entity.AsynchModelCol.IsChatModel))
// queryParams = append(queryParams, req.IsChatModel)
// }
// if !g.IsEmpty(req.ModelName) {
// whereCondition.WriteString(fmt.Sprintf(" AND %s = (?) ", entity.AsynchModelCol.ModelName))
// queryParams = append(queryParams, req.ModelName)
// }
// // 完整 SQL
// sql := `SELECT * FROM "asynch_models" WHERE "deleted_at" IS NULL` + whereCondition.String()
// r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, sql, queryParams...)
// if err != nil {
// return
// }
// var i []*entity.AsynchModel
// if err = r.Structs(&i); err != nil {
// return nil, err
// }
// for _, item := range i {
// m = item
// }
// return
//}
// GetByAcrossTenant 按ID获取跨租户查所有租户
func (d *modelDao) GetByAcrossTenant(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
NoTenantId(ctx).
OmitEmpty().
Where(entity.AsynchModelCol.Id, req.Id).
Where(entity.AsynchModelCol.Creator, req.Creator).
Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel).
Where(entity.AsynchModelCol.ModelName, req.ModelName).
Fields(fields).One()
if err != nil {
return nil, 0, err
return
}
total = gconv.Int64(totalInt)
err = r.Structs(&list)
err = r.Struct(&m)
return
}
// ListByCreatorAndPlatform 普通用户:平台公共(tenant_id=0) + 自己创建的(creator=xxx)
func (d *modelDao) ListByCreatorAndPlatform(ctx context.Context, creator string, pageNum, pageSize int, modelNameLike string) (list []*entity.AsynchModel, total int64, err error) {
// 构建 Where 条件
whereSQL := "deleted_at IS NULL AND (tenant_id = 1 OR creator = ?)" //1 代表超级管理员
args := []any{creator}
if modelNameLike != "" {
whereSQL += " AND model_name LIKE ?"
args = append(args, "%"+modelNameLike+"%")
}
// 查总数
countSQL := fmt.Sprintf("SELECT COUNT(1) FROM %s WHERE %s", public.TableNameModel, whereSQL)
countResult, err := gfdb.DB(ctx).GetAll(ctx, countSQL, args...)
if err != nil {
return nil, 0, err
}
if len(countResult) > 0 {
total = gconv.Int64(countResult[0]["count"])
}
// 查列表
querySQL := fmt.Sprintf("SELECT * FROM %s WHERE %s ORDER BY created_at DESC", public.TableNameModel, whereSQL)
if pageNum > 0 && pageSize > 0 {
offset := (pageNum - 1) * pageSize
querySQL += fmt.Sprintf(" LIMIT %d OFFSET %d", pageSize, offset)
}
r, err := gfdb.DB(ctx).GetAll(ctx, querySQL, args...)
if err != nil {
return nil, 0, err
}
err = r.Structs(&list)
return
}
// GetByCreatorAndPlatform 按创建者、平台获取
func (d *modelDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.AsynchModel, total int, err error) {
// 基础 SQL
sql := `
SELECT DISTINCT ON (model_name) *
FROM asynch_models
WHERE deleted_at IS NULL
AND (? = '' OR model_name LIKE ?)
AND (? = 0 OR model_type = ?)
`
args := []any{
req.ModelName, "%" + req.ModelName + "%",
req.ModelType, req.ModelType,
}
// modelType: 传 6 模糊匹配 6%
if req.ModelType > 0 {
prefix := strconv.Itoa(req.ModelType)[:1] // 截取第一位
sql += ` AND model_type::text LIKE ? `
args = append(args, prefix+"%")
}
if !g.IsEmpty(req.IsPrivate) {
sql += ` AND is_private = ? `
args = append(args, req.IsPrivate)
}
if req.IsOwner != nil && *req.IsOwner == 0 {
sql += ` AND creator = ? AND is_owner = ? `
args = append(args, req.Creator)
args = append(args, req.IsOwner)
if req.Enabled != nil && *req.Enabled == 1 {
sql += ` AND creator = ? AND is_owner = ? AND enabled=1 `
} else if req.Enabled != nil && *req.Enabled == 0 {
sql += ` AND creator = ? AND is_owner = ? AND enabled=0 `
} else {
sql += ` AND creator = ? AND is_owner = ? `
}
args = append(args, req.Creator, req.IsOwner)
} else if req.IsOwner != nil && *req.IsOwner == 1 {
if req.Enabled != nil && *req.Enabled == 1 {
sql += ` AND ((creator = ? AND is_owner = ? AND enabled=1) OR (is_owner = 0 AND enabled=1)) `
@@ -179,14 +166,12 @@ WHERE deleted_at IS NULL
} else {
sql += ` AND ((creator = ? AND is_owner = ?) OR (is_owner = 0 AND enabled=1)) `
}
args = append(args, req.Creator)
args = append(args, req.IsOwner)
args = append(args, req.Creator, req.IsOwner)
}
// 最后拼接排序
sql += ` ORDER BY model_name, is_owner DESC, created_at DESC`
r, err := gfdb.DB(ctx).GetAll(ctx, sql, args...)
r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, sql, args...)
if err != nil {
return nil, 0, err
}
@@ -200,33 +185,24 @@ WHERE deleted_at IS NULL
return
}
func (d *modelDao) GetByIsChatModel(ctx context.Context) (m *entity.AsynchModel, err error) {
userInfo, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
Where(entity.AsynchModelCol.IsChatModel, 1).
Where(entity.AsynchModelCol.Creator, userInfo.UserName).
One()
// GetByModelNameForTenant 后台任务使用:按 tenant_id + model_name 查询,不依赖 gfdb Hook/Trace/用户上下文
func (d *modelDao) GetByModelNameForTenant(ctx context.Context, tenantId uint64, modelName string) (m *entity.AsynchModel, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx,
"SELECT * FROM "+public.TableNameModel+" WHERE tenant_id=? AND model_name=? AND deleted_at IS NULL LIMIT 1",
tenantId, modelName,
)
if err != nil {
return nil, err
}
if r.IsEmpty() {
return nil, nil
}
err = r.Struct(&m)
return
}
// ListAll 用于分组展示:查询全部模型(不按类型过滤,类型拆分在 service 层处理)
func (d *modelDao) ListAll(ctx context.Context) (list []*entity.AsynchModel, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
OrderDesc(entity.AsynchModelCol.CreatedAt).
All()
if err != nil {
var list []*entity.AsynchModel
if err := r.Structs(&list); err != nil {
return nil, err
}
err = r.Structs(&list)
return
if len(list) == 0 {
return nil, nil
}
return list[0], nil
}

View File

@@ -1,32 +0,0 @@
package dao
import (
"context"
"model-gateway/consts/public"
"model-gateway/model/entity"
"gitea.com/red-future/common/db/gfdb"
)
// GetByModelNameForTenant 后台任务使用:按 tenant_id + model_name 查询,不依赖 gfdb Hook/Trace/用户上下文
func (d *modelDao) GetByModelNameForTenant(ctx context.Context, tenantId uint64, modelName string) (m *entity.AsynchModel, err error) {
r, err := gfdb.DB(ctx).GetAll(ctx,
"SELECT * FROM "+public.TableNameModel+" WHERE tenant_id=? AND model_name=? AND deleted_at IS NULL LIMIT 1",
tenantId, modelName,
)
if err != nil {
return nil, err
}
if r.IsEmpty() {
return nil, nil
}
var list []*entity.AsynchModel
if err := r.Structs(&list); err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
return list[0], nil
}

View File

@@ -6,15 +6,23 @@ import (
"model-gateway/consts/public"
"model-gateway/model/entity"
"gitea.com/red-future/common/db/gfdb"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/util/gconv"
)
type opLogDao struct{}
var OpLog = &opLogDao{}
func (d *opLogDao) Insert(ctx context.Context, log *entity.LogsModelOp) (id int64, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameOpLog).Data(log).Insert()
// Insert 插入
func (d *opLogDao) Insert(ctx context.Context, req *entity.LogsModelOp) (id int64, err error) {
m := new(entity.LogsModelOp)
err = gconv.Struct(req, &m)
if err != nil {
return
}
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameOpLog).
Insert(m)
if err != nil {
return 0, err
}

View File

@@ -8,7 +8,7 @@ import (
"model-gateway/consts/public"
"model-gateway/model/entity"
"gitea.com/red-future/common/db/gfdb"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/os/gtime"
)
@@ -25,7 +25,7 @@ ON CONFLICT (day, tenant_id, creator, model_name)
DO UPDATE SET request_count = %s.request_count + 1, updated_at = NOW()`,
public.TableNameStat, public.TableNameStat,
)
_, err := gfdb.DB(ctx).Exec(ctx, sql, gtime.New(day).Format("Y-m-d"), tenantId, creator, modelName)
_, err := gfdb.DB(ctx, public.DbNameModelGateway).Exec(ctx, sql, gtime.New(day).Format("Y-m-d"), tenantId, creator, modelName)
return err
}

View File

@@ -2,13 +2,10 @@ package dao
import (
"context"
"fmt"
"time"
"model-gateway/consts/public"
"model-gateway/model/entity"
"gitea.com/red-future/common/db/gfdb"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/util/gconv"
@@ -18,40 +15,60 @@ var Task = &taskDao{}
type taskDao struct{}
func (d *taskDao) Insert(ctx context.Context, t *entity.AsynchTask) (id int64, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).Data(t).Insert()
// Insert 插入
func (d *taskDao) Insert(ctx context.Context, req *entity.AsynchTask) (id int64, err error) {
m := new(entity.AsynchTask)
err = gconv.Struct(req, &m)
if err != nil {
return 0, err
return
}
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
Insert(m)
if err != nil {
return
}
return r.LastInsertId()
}
func (d *taskDao) GetByTaskID(ctx context.Context, taskID string) (t *entity.AsynchTask, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
Where(entity.AsynchTaskCol.TaskID, taskID).
One()
// Update 更新
func (d *taskDao) Update(ctx context.Context, req *entity.AsynchTask) (rows int64, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
OmitEmpty().
Data(&req).
Where(entity.AsynchTaskCol.Id, req.Id).
Update()
if err != nil {
return nil, err
return
}
if r.IsEmpty() {
return nil, nil
return r.RowsAffected()
}
// Get 获取
func (d *taskDao) Get(ctx context.Context, req *entity.AsynchTask, fields ...string) (m *entity.AsynchTask, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
OmitEmpty().
Where(entity.AsynchTaskCol.TaskID, req.TaskID).
Fields(fields).One()
if err != nil {
return
}
err = r.Struct(&t)
err = r.Struct(&m)
return
}
// ListByTaskIDs 批量查询任务(会受 gfdb 的租户 Hook 影响,只返回当前租户数据)
func (d *taskDao) ListByTaskIDs(ctx context.Context, taskIDs []string) (list []*entity.AsynchTask, err error) {
func (d *taskDao) ListByTaskIDs(ctx context.Context, taskIDs []string) (m []*entity.AsynchTask, err error) {
if len(taskIDs) == 0 {
return nil, nil
}
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
OmitEmpty().
WhereIn(entity.AsynchTaskCol.TaskID, taskIDs).
All()
if err != nil {
return nil, err
}
err = r.Structs(&list)
err = r.Structs(&m)
return
}
@@ -62,7 +79,7 @@ func (d *taskDao) MarkDownloadedByID(ctx context.Context, id int64, expireAt *gt
entity.AsynchTaskCol.ExpireAt: expireAt,
entity.AsynchTaskCol.Updater: "",
}
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
_, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
Where(entity.AsynchTaskCol.Id, id).
Where(entity.AsynchTaskCol.State, 2).
Data(data).
@@ -70,76 +87,9 @@ func (d *taskDao) MarkDownloadedByID(ctx context.Context, id int64, expireAt *gt
return err
}
func (d *taskDao) UpdateRunning(ctx context.Context, id int64) error {
now := gtime.Now()
data := gdb.Map{
entity.AsynchTaskCol.State: 1,
entity.AsynchTaskCol.StartedAt: now,
entity.AsynchTaskCol.Updater: "",
}
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
Where(entity.AsynchTaskCol.Id, id).
Data(data).
Update()
return err
}
func (d *taskDao) UpdateSuccess(ctx context.Context, id int64, ossFile, fileType string, fileSize int64, expireAt *gtime.Time) error {
now := gtime.Now()
data := gdb.Map{
entity.AsynchTaskCol.State: 2,
entity.AsynchTaskCol.OssFile: ossFile,
entity.AsynchTaskCol.FileType: fileType,
entity.AsynchTaskCol.FileSize: fileSize,
entity.AsynchTaskCol.ErrorMsg: "",
entity.AsynchTaskCol.FinishedAt: now,
entity.AsynchTaskCol.ExpireAt: expireAt,
entity.AsynchTaskCol.Updater: "",
}
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
Where(entity.AsynchTaskCol.Id, id).
Data(data).
Update()
return err
}
func (d *taskDao) UpdateFailed(ctx context.Context, id int64, errorMsg string) error {
now := gtime.Now()
data := gdb.Map{
entity.AsynchTaskCol.State: 3,
entity.AsynchTaskCol.ErrorMsg: errorMsg,
entity.AsynchTaskCol.FinishedAt: now,
entity.AsynchTaskCol.Updater: "",
}
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
Where(entity.AsynchTaskCol.Id, id).
Data(data).
Update()
return err
}
func (d *taskDao) SoftDeleteByTaskID(ctx context.Context, taskID string) (rows int64, err error) {
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
Where(entity.AsynchTaskCol.TaskID, taskID).
Delete()
if err != nil {
return 0, err
}
return r.RowsAffected()
}
// CountActiveByModel 统计某模型排队中/执行中的任务数,用于 queue_limit 限制(近似值)
func (d *taskDao) CountActiveByModel(ctx context.Context, modelName string) (int64, error) {
n, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
Where(entity.AsynchTaskCol.ModelName, modelName).
WhereIn(entity.AsynchTaskCol.State, []int{0, 1}).
Count()
return int64(n), err
}
// List 任务分页查询(受 gfdb 租户 Hook 影响)
func (d *taskDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike, taskIDLike string, state *int) (list []*entity.AsynchTask, total int64, err error) {
m := gfdb.DB(ctx).Model(ctx, public.TableNameTask).Where("deleted_at IS NULL")
m := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).Where("deleted_at IS NULL")
if modelNameLike != "" {
m = m.WhereLike(entity.AsynchTaskCol.ModelName, "%"+modelNameLike+"%")
}
@@ -162,89 +112,13 @@ func (d *taskDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike
return
}
// ClaimPending 抢占 pending 任务state=0并在同一事务中更新为 runningstate=1
// 使用 PostgreSQL: FOR UPDATE SKIP LOCKED 避免多 worker 重复消费
func (d *taskDao) ClaimPending(ctx context.Context, batchSize int) (tasks []*entity.AsynchTask, err error) {
if batchSize <= 0 {
batchSize = 1
}
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
sql := fmt.Sprintf(
`SELECT id, tenant_id, model_name, task_id, input_ref, request_payload
FROM %s
WHERE deleted_at IS NULL AND state = 0
ORDER BY created_at ASC
LIMIT %d
FOR UPDATE SKIP LOCKED`,
public.TableNameTask,
batchSize,
)
r, err := tx.GetAll(sql)
if err != nil {
return err
}
if r.IsEmpty() {
tasks = nil
return nil
}
if err := r.Structs(&tasks); err != nil {
return err
}
// 更新为 running
now := time.Now()
for _, t := range tasks {
// tx.Model 不走 gfdb Hook这里手动更新必要字段
_, err = tx.Exec(
fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask),
now, now, t.Id,
)
if err != nil {
return err
}
}
return nil
})
return
}
// ListExpiredSuccess 获取已成功且过期的任务
func (d *taskDao) ListExpiredSuccess(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
if limit <= 0 {
limit = 100
}
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
Where(entity.AsynchTaskCol.State, 2).
Where(entity.AsynchTaskCol.ExpireAt+" IS NOT NULL").
Where(entity.AsynchTaskCol.ExpireAt+" < ?", gtime.Now()).
// GetPendingAsyncTasks 获取进行中的异步任务
func (d *taskDao) GetPendingAsyncTasks(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
var tasks []*entity.AsynchTask
err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
Where("state", 1).
Where("deleted_at IS NULL").
Limit(limit).
All()
if err != nil {
return nil, err
}
err = r.Structs(&list)
return
}
// ListTimeoutTasks 获取超时的排队/执行中任务
func (d *taskDao) ListTimeoutTasks(ctx context.Context, timeout time.Duration, limit int) (list []*entity.AsynchTask, err error) {
if limit <= 0 {
limit = 100
}
deadline := gtime.New(time.Now().Add(-timeout))
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
WhereIn(entity.AsynchTaskCol.State, []int{0, 1}).
Where(entity.AsynchTaskCol.UpdatedAt+" < ?", deadline).
Limit(limit).
All()
if err != nil {
return nil, err
}
err = r.Structs(&list)
return
}
// DebugPing 用于启动时检测数据库连通性(可选)
func (d *taskDao) DebugPing(ctx context.Context) error {
_, err := gfdb.DB(ctx).GetAll(ctx, "SELECT 1")
return err
Scan(&tasks)
return tasks, err
}

View File

@@ -8,33 +8,58 @@ import (
"model-gateway/consts/public"
"model-gateway/model/entity"
"gitea.com/red-future/common/db/gfdb"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/os/gtime"
)
// ClaimPendingGlobal 后台任务使用:全局抢占 pending 任务(不加 tenant 过滤)
func (d *taskDao) ClaimPendingGlobal(ctx context.Context, batchSize int) (tasks []*entity.AsynchTask, err error) {
// ======================== 查询辅助 ========================
// taskColumns 查询用的公共字段
const taskColumns = `id, tenant_id, creator, model_name, task_id, biz_name, callback_url, model_key, retry_count, input_ref, request_payload, phase, tmp_file`
// ======================== 事务抢占 ========================
// claimTasks 事务内抢占任务并更新 state=1
func claimTasks(ctx context.Context, where string, args ...any) ([]*entity.AsynchTask, error) {
var tasks []*entity.AsynchTask
err := gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
sql := fmt.Sprintf(`SELECT %s FROM %s WHERE deleted_at IS NULL AND state = 0 %s LIMIT 1 FOR UPDATE SKIP LOCKED`, taskColumns, public.TableNameTask, where)
r, err := tx.GetOne(sql, args...)
if err != nil {
return err
}
if r.IsEmpty() {
return nil
}
var task entity.AsynchTask
if err := r.Struct(&task); err != nil {
return err
}
now := time.Now()
_, err = tx.Exec(fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask), now, now, task.Id)
if err != nil {
return err
}
tasks = []*entity.AsynchTask{&task}
return nil
})
return tasks, err
}
// ClaimPendingGlobal 批量抢占 pending 任务
func (d *taskDao) ClaimPendingGlobal(ctx context.Context, batchSize int) ([]*entity.AsynchTask, error) {
if batchSize <= 0 {
batchSize = 1
}
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
sql := fmt.Sprintf(
`SELECT id, tenant_id, creator, model_name, task_id, biz_name, callback_url, model_key, input_ref, request_payload, phase, tmp_file
FROM %s
WHERE deleted_at IS NULL AND state = 0
ORDER BY enqueue_at ASC
LIMIT %d
FOR UPDATE SKIP LOCKED`,
public.TableNameTask,
batchSize,
)
var tasks []*entity.AsynchTask
err := gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
sql := fmt.Sprintf(`SELECT %s FROM %s WHERE deleted_at IS NULL AND state = 0 ORDER BY enqueue_at ASC LIMIT %d FOR UPDATE SKIP LOCKED`, taskColumns, public.TableNameTask, batchSize)
r, err := tx.GetAll(sql)
if err != nil {
return err
}
if r.IsEmpty() {
tasks = nil
return nil
}
if err := r.Structs(&tasks); err != nil {
@@ -42,245 +67,148 @@ func (d *taskDao) ClaimPendingGlobal(ctx context.Context, batchSize int) (tasks
}
now := time.Now()
for _, t := range tasks {
_, err = tx.Exec(
fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask),
now, now, t.Id,
)
_, err = tx.Exec(fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask), now, now, t.Id)
if err != nil {
return err
}
}
return nil
})
return
return tasks, err
}
// ClaimPendingByTaskIDGlobal 按 task_id 定向抢占单个 pending 任务(不加 tenant 过滤)
// 用于 createTask 创建成功后立即异步尝试执行当前任务,避免只依赖后续 runWork 扫描队列。
func (d *taskDao) ClaimPendingByTaskIDGlobal(ctx context.Context, taskID string) (task *entity.AsynchTask, err error) {
// ClaimPendingByTaskIDGlobal 按 task_id 抢占
func (d *taskDao) ClaimPendingByTaskIDGlobal(ctx context.Context, taskID string) (*entity.AsynchTask, error) {
if taskID == "" {
return nil, nil
}
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
sql := fmt.Sprintf(
`SELECT id, tenant_id, creator, model_name, task_id, biz_name, callback_url, model_key, input_ref, request_payload, phase, tmp_file
FROM %s
WHERE deleted_at IS NULL AND state = 0 AND task_id = ?
LIMIT 1
FOR UPDATE SKIP LOCKED`,
public.TableNameTask,
)
r, err := tx.GetOne(sql, taskID)
if err != nil {
return err
}
if r.IsEmpty() {
task = nil
return nil
}
if err := r.Struct(&task); err != nil {
return err
}
now := time.Now()
_, err = tx.Exec(
fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask),
now, now, task.Id,
)
return err
tasks, err := claimTasks(ctx, "AND task_id = ?", taskID)
if err != nil || len(tasks) == 0 {
return nil, err
}
return tasks[0], nil
}
// ======================== 更新辅助 ========================
func execSQL(ctx context.Context, sql string, args ...any) error {
_, err := gfdb.DB(ctx).Exec(ctx, sql, args...)
return err
}
// updateTask 通用更新
func updateTask(ctx context.Context, id int64, data entity.AsynchTask) error {
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).OmitEmpty().
Where(entity.AsynchTaskCol.Id, id).Data(data).Update()
return err
}
// UpdateSuccessGlobal 更新任务成功
func (d *taskDao) UpdateSuccessGlobal(ctx context.Context, t *entity.AsynchTask) error {
return updateTask(ctx, t.Id, entity.AsynchTask{
State: 2,
OssFile: t.OssFile,
FileType: t.FileType,
TextResult: t.TextResult,
FileSize: t.FileSize,
ErrorMsg: "",
FinishedAt: gtime.Now(),
Phase: 0,
TmpFile: "",
ExpendTokens: t.ExpendTokens,
DurationSeconds: t.DurationSeconds,
})
return
}
func (d *taskDao) UpdateSuccessGlobal(ctx context.Context, id int64, ossFile, fileType, textResult string, fileSize int64, expireAt *gtime.Time, expendTokens int) error {
now := gtime.Now()
_, err := gfdb.DB(ctx).Exec(ctx,
fmt.Sprintf(`UPDATE %s
SET state=2,
oss_file=?,
file_type=?,
text_result=?,
expend_tokens=?,
file_size=?,
error_msg='',
finished_at=?,
duration_seconds=EXTRACT(EPOCH FROM (? - created_at))::BIGINT,
expire_at=NULL,
phase=0,
tmp_file='',
updated_at=?
WHERE id=?`, public.TableNameTask),
ossFile, fileType, textResult, expendTokens, fileSize, now, now, now, id,
)
return err
// UpdateFailedGlobal 模型调用失败
func (d *taskDao) UpdateFailedGlobal(ctx context.Context, t *entity.AsynchTask) error {
return updateTask(ctx, t.Id, entity.AsynchTask{
State: 3,
ErrorMsg: t.ErrorMsg,
FinishedAt: gtime.Now(),
Phase: 0,
TmpFile: "",
TextResult: t.TextResult,
DurationSeconds: t.DurationSeconds,
})
}
func (d *taskDao) UpdateFailedGlobal(ctx context.Context, id int64, errorMsg string) error {
now := gtime.Now()
_, err := gfdb.DB(ctx).Exec(ctx,
fmt.Sprintf(`UPDATE %s
SET state=3,
error_msg=?,
finished_at=?,
duration_seconds=EXTRACT(EPOCH FROM (? - created_at))::BIGINT,
phase=0,
tmp_file='',
updated_at=?
WHERE id=?`, public.TableNameTask),
errorMsg, now, now, now, id,
)
return err
}
// UpdateFailedKeepTmpGlobal OSS 上传失败:保留 phase/tmp_file下一轮仅重试 OSS 上传
// UpdateFailedKeepTmpGlobal OSS 上传失败
func (d *taskDao) UpdateFailedKeepTmpGlobal(ctx context.Context, id int64, errorMsg string) error {
now := gtime.Now()
_, err := gfdb.DB(ctx).Exec(ctx,
fmt.Sprintf(`UPDATE %s SET state=3, error_msg=?, finished_at=?, phase=1, updated_at=? WHERE id=?`, public.TableNameTask),
errorMsg, now, now, id,
)
return err
return execSQL(ctx, fmt.Sprintf(`UPDATE %s SET state=3, error_msg=?, finished_at=?, phase=1, updated_at=? WHERE id=?`, public.TableNameTask), errorMsg, gtime.Now(), gtime.Now(), id)
}
// UpdateTmpAfterModelGlobal 模型调用成功后,写入临时文件路径并标记 phase=1
// UpdateTmpAfterModelGlobal 写临时文件
func (d *taskDao) UpdateTmpAfterModelGlobal(ctx context.Context, id int64, tmpFile string) error {
_, err := gfdb.DB(ctx).Exec(ctx,
fmt.Sprintf(`UPDATE %s SET phase=1, tmp_file=?, updated_at=NOW() WHERE id=?`, public.TableNameTask),
tmpFile, id,
)
return err
}
func (d *taskDao) SoftDeleteByTaskIDGlobal(ctx context.Context, taskID string) error {
_, err := gfdb.DB(ctx).Exec(ctx,
fmt.Sprintf(`UPDATE %s SET deleted_at=NOW(), updated_at=NOW() WHERE task_id=? AND deleted_at IS NULL`, public.TableNameTask),
taskID,
)
return err
return execSQL(ctx, fmt.Sprintf(`UPDATE %s SET phase=1, tmp_file=?, updated_at=NOW() WHERE id=?`, public.TableNameTask), tmpFile, id)
}
// RollbackToPendingGlobal 回滚
func (d *taskDao) RollbackToPendingGlobal(ctx context.Context, id int64) error {
_, err := gfdb.DB(ctx).Exec(ctx,
fmt.Sprintf(`UPDATE %s SET state=0, enqueue_at=NOW(), updated_at=NOW() WHERE id=? AND state=1`, public.TableNameTask),
id,
)
return err
return execSQL(ctx, fmt.Sprintf(`UPDATE %s SET state=0, enqueue_at=NOW(), updated_at=NOW() WHERE id=? AND state=1`, public.TableNameTask), id)
}
// ListExpiredDownloadedGlobal 获取已下载(state=4)且过期的任务,用于清理
func (d *taskDao) ListExpiredDownloadedGlobal(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
if limit <= 0 {
limit = 200
}
r, err := gfdb.DB(ctx).GetAll(ctx,
fmt.Sprintf(`SELECT * FROM %s WHERE deleted_at IS NULL AND state=4 AND expire_at IS NOT NULL AND expire_at < ? LIMIT ?`, public.TableNameTask),
gtime.Now(), limit,
)
if err != nil {
return nil, err
}
err = r.Structs(&list)
return
// IncRetryCountGlobal 重试计数+1
func (d *taskDao) IncRetryCountGlobal(ctx context.Context, id int64) error {
return execSQL(ctx, fmt.Sprintf(`UPDATE %s SET retry_count=retry_count+1, updated_at=NOW() WHERE id=?`, public.TableNameTask), id)
}
// ListFailedRetryableGlobal 获取失败(state=3)且仍可重试的任务
// retry_count 不含首次执行retry_times 表示失败后最多再重试 N 次
func (d *taskDao) ListFailedRetryableGlobal(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
if limit <= 0 {
limit = 200
}
r, err := gfdb.DB(ctx).GetAll(ctx,
fmt.Sprintf(`
SELECT t.*,
m.retry_queue_max_seconds AS retry_queue_max_seconds
FROM %s t
JOIN %s m
ON t.tenant_id = m.tenant_id
AND t.model_name = m.model_name
WHERE t.deleted_at IS NULL
AND t.state = 3
AND t.retry_count < m.retry_times
ORDER BY t.updated_at ASC
LIMIT ?`, public.TableNameTask, public.TableNameModel),
limit,
)
if err != nil {
return nil, err
}
err = r.Structs(&list)
return
}
// RequeueForRetryGlobal 将任务重新入队state=0并将 retry_count +1
// enqueueAt 用于控制重试任务在队列中的位置:
// - enqueueAt 越早越靠前ClaimPendingGlobal 按 enqueue_at ASC 抢占)
// RequeueForRetryGlobal 重新入队
func (d *taskDao) RequeueForRetryGlobal(ctx context.Context, id int64, enqueueAt time.Time) error {
_, err := gfdb.DB(ctx).Exec(ctx,
fmt.Sprintf(`UPDATE %s SET state=0, retry_count=retry_count+1, enqueue_at=?, updated_at=NOW() WHERE id=? AND state=3 AND deleted_at IS NULL`, public.TableNameTask),
enqueueAt, id,
)
return err
return execSQL(ctx, fmt.Sprintf(`UPDATE %s SET state=0, retry_count=retry_count+1, enqueue_at=?, updated_at=NOW() WHERE id=? AND state=3 AND deleted_at IS NULL`, public.TableNameTask), enqueueAt, id)
}
// ListFailedExhaustedGlobal 获取失败(state=3)且超过重试次数的任务,用于硬删除
func (d *taskDao) ListFailedExhaustedGlobal(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
if limit <= 0 {
limit = 200
}
r, err := gfdb.DB(ctx).GetAll(ctx,
fmt.Sprintf(`
SELECT t.*
FROM %s t
JOIN %s m
ON t.tenant_id = m.tenant_id
AND t.model_name = m.model_name
WHERE t.deleted_at IS NULL
AND t.state = 3
AND t.retry_count >= m.retry_times
ORDER BY t.updated_at ASC
LIMIT ?`, public.TableNameTask, public.TableNameModel),
limit,
)
if err != nil {
return nil, err
}
err = r.Structs(&list)
return
// ======================== 列表查询 ========================
// ListExpiredDownloadedGlobal
func (d *taskDao) ListExpiredDownloadedGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
return queryTasks(ctx, fmt.Sprintf(`SELECT * FROM %s WHERE deleted_at IS NULL AND state=4 AND expire_at IS NOT NULL AND expire_at < ? LIMIT ?`, public.TableNameTask), gtime.Now(), clampLimit(limit, 200))
}
// HardDeleteByIDGlobal 硬删除任务记录
// ListFailedRetryableGlobal
func (d *taskDao) ListFailedRetryableGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
return queryTasks(ctx, fmt.Sprintf(`SELECT t.*, m.retry_queue_max_seconds FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state=3 AND t.retry_count < m.retry_times ORDER BY t.updated_at ASC LIMIT ?`, public.TableNameTask, public.TableNameModel), clampLimit(limit, 200))
}
// ListFailedExhaustedGlobal
func (d *taskDao) ListFailedExhaustedGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
return queryTasks(ctx, fmt.Sprintf(`SELECT t.* FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state=3 AND t.retry_count >= m.retry_times ORDER BY t.updated_at ASC LIMIT ?`, public.TableNameTask, public.TableNameModel), clampLimit(limit, 200))
}
// ListTimeoutTasksGlobal
func (d *taskDao) ListTimeoutTasksGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
return queryTasks(ctx, fmt.Sprintf(`SELECT t.* FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state IN (0,1) AND m.expected_seconds > 0 AND t.created_at < (NOW() - (m.expected_seconds || ' seconds')::interval) LIMIT ?`, public.TableNameTask, public.TableNameModel), clampLimit(limit, 200))
}
// HardDeleteByIDGlobal
func (d *taskDao) HardDeleteByIDGlobal(ctx context.Context, id int64) error {
_, err := gfdb.DB(ctx).Exec(ctx,
fmt.Sprintf(`DELETE FROM %s WHERE id=?`, public.TableNameTask),
id,
)
return err
return execSQL(ctx, fmt.Sprintf(`DELETE FROM %s WHERE id=?`, public.TableNameTask), id)
}
// ListTimeoutTasksGlobal 根据模型配置 expected_seconds 判定超时任务:
// - state in (0,1)
// - 模型 expected_seconds > 0
// - now - created_at >= expected_seconds
func (d *taskDao) ListTimeoutTasksGlobal(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
if limit <= 0 {
limit = 200
}
r, err := gfdb.DB(ctx).GetAll(ctx,
fmt.Sprintf(`
SELECT t.*
FROM %s t
JOIN %s m
ON t.tenant_id = m.tenant_id
AND t.model_name = m.model_name
WHERE t.deleted_at IS NULL
AND t.state IN (0,1)
AND m.expected_seconds > 0
AND t.created_at < (NOW() - (m.expected_seconds || ' seconds')::interval)
LIMIT ?`, public.TableNameTask, public.TableNameModel),
limit,
)
// ======================== 内部辅助 ========================
func queryTasks(ctx context.Context, sql string, args ...any) ([]*entity.AsynchTask, error) {
r, err := gfdb.DB(ctx).GetAll(ctx, sql, args...)
if err != nil {
return nil, err
}
var list []*entity.AsynchTask
err = r.Structs(&list)
return
return list, err
}
func clampLimit(limit, defaultVal int) int {
if limit <= 0 {
return defaultVal
}
return limit
}
// UpdateColumns 更新指定字段(结构体版)
func (d *taskDao) UpdateColumns(ctx context.Context, id int64, data entity.AsynchTask) error {
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).OmitEmpty().
Where(entity.AsynchTaskCol.Id, id).
Data(data).
Update()
return err
}

34
go.mod
View File

@@ -1,22 +1,14 @@
module model-gateway
go 1.26.0
go 1.26.1
require (
gitea.com/red-future/common v0.0.19
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.0
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.0
github.com/gogf/gf/v2 v2.10.0
gitea.redpowerfuture.com/red-future/common v0.0.23
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.2
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.2
github.com/gogf/gf/v2 v2.10.2
github.com/google/uuid v1.6.0
github.com/tidwall/gjson v1.14.2
)
require (
github.com/r3labs/diff/v2 v2.15.1 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/vmihailenco/msgpack v4.0.4+incompatible // indirect
google.golang.org/appengine v1.6.7 // indirect
github.com/tidwall/gjson v1.19.0
)
require (
@@ -57,7 +49,7 @@ require (
github.com/hashicorp/go-rootcerts v1.0.2 // indirect
github.com/hashicorp/golang-lru v1.0.2 // indirect
github.com/hashicorp/serf v0.10.1 // indirect
github.com/klauspost/compress v1.18.2 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/lib/pq v1.10.9 // indirect
github.com/magiconair/properties v1.8.10 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
@@ -69,11 +61,14 @@ require (
github.com/olekukonko/ll v0.0.9 // indirect
github.com/olekukonko/tablewriter v1.1.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/r3labs/diff/v2 v2.15.1 // indirect
github.com/redis/go-redis/v9 v9.12.1 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/tidwall/sjson v1.2.5
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/tiger1103/gfast-token v1.0.10 // indirect
github.com/vcaesar/cedar v0.30.0 // indirect
github.com/vmihailenco/msgpack v4.0.4+incompatible // indirect
go.mongodb.org/mongo-driver/v2 v2.4.0 // indirect
go.opencensus.io v0.23.0 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
@@ -85,9 +80,10 @@ require (
go.opentelemetry.io/otel/trace v1.38.0 // indirect
go.opentelemetry.io/proto/otlp v1.7.1 // indirect
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect
golang.org/x/net v0.48.0 // indirect
golang.org/x/sys v0.39.0 // indirect
golang.org/x/text v0.32.0 // indirect
golang.org/x/net v0.47.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect
google.golang.org/grpc v1.75.0 // indirect

52
go.sum
View File

@@ -1,6 +1,6 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
gitea.com/red-future/common v0.0.19 h1:9/WrfCFUCeFUYwuhBYF+JOQi5F5xuOy+gVnf2ZvHZu4=
gitea.com/red-future/common v0.0.19/go.mod h1:6/nqIucVzmjOyqDTIq71feYBXXFNBy0rFwzaQ0/Ueoo=
gitea.redpowerfuture.com/red-future/common v0.0.23 h1:xieoA00iKOCDm5SO9iXn+cSyMKBAlZwI0fuEVPWrHLg=
gitea.redpowerfuture.com/red-future/common v0.0.23/go.mod h1:50U1Xi+Ie56z09S5LQbZvaken0Mxv3OeS9LgR7U/ZRY=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
@@ -77,16 +77,16 @@ github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ4
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.0 h1:39+jbTenm7KBj4hO2C8ANAxVHpX/7OuRDs1VcGC9ylA=
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.0/go.mod h1:B0s0fVzn0W220E8UTpSGzrrGKsop5KcB90twBeLCiz0=
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.0 h1:N/F9CuDdUZLoM1nVRqrDE/33pDZuhVxpNY4wYdeIaBs=
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.0/go.mod h1:x6uoJGfZOtirIRQls8xUlYzC6f7T/eULPUa9er368X0=
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.2 h1:u8EpP24GkprogROnJ7htMov9Fc66pTP1eVYrWxiCYOs=
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.2/go.mod h1:GmvM3r8GVByVMi4RD2+MCs5+CfxVXPMeT8mVDkAaAXE=
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.2 h1:iTQegT+lEg/wDKvj2mi3W1wrdrwFarjokf88EXVVgu4=
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.2/go.mod h1:ZRw3GNz5cq4uYrW4TPSVyrYWaoqzujKdWro/AOcGBaE=
github.com/gogf/gf/contrib/registry/consul/v2 v2.9.5 h1:eUqwJ/qNH8lJ6yssiqskazgp1ACQuNU6zXlLOZVuXTQ=
github.com/gogf/gf/contrib/registry/consul/v2 v2.9.5/go.mod h1:sjQyMry9+0POYZCA6lHXBxO77WoNKkruJpRB4xKqk5k=
github.com/gogf/gf/contrib/trace/otlphttp/v2 v2.9.5 h1:tHUEZYB5GTqEYYVDYnlGobf1xISARKDE4KHVlgjwTec=
github.com/gogf/gf/contrib/trace/otlphttp/v2 v2.9.5/go.mod h1:cfzTn2HS9RDX8f5pUVkbGxUWcSosouqfNQ1G6cY0V88=
github.com/gogf/gf/v2 v2.10.0 h1:rzDROlyqGMe/eM6dCalSR8dZOuMIdLhmxKSH1DGhbFs=
github.com/gogf/gf/v2 v2.10.0/go.mod h1:Svl1N+E8G/QshU2DUbh/3J/AJauqCgUnxHurXWR4Qx0=
github.com/gogf/gf/v2 v2.10.2 h1:46IO0Uc8e85/FqdftJFskfDejJLBL0JBnGS5qOftUu8=
github.com/gogf/gf/v2 v2.10.2/go.mod h1:Svl1N+E8G/QshU2DUbh/3J/AJauqCgUnxHurXWR4Qx0=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
@@ -185,8 +185,8 @@ github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/u
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk=
github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
@@ -288,14 +288,12 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tidwall/gjson v1.14.2 h1:6BBkirS0rAHjumnjHF6qgy5d2YAJ1TLIaFE2lzfOLqo=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.19.0 h1:xwxm7n691Uf3u5OFjzngavjGTh55KX5q/9w9xHW88JU=
github.com/tidwall/gjson v1.19.0/go.mod h1:V37/opeE/JbLUOfH0QTXiNez2l0RUjYUhpT4szFQAfc=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/tiger1103/gfast-token v1.0.10 h1:fNiBE/Dq5iTHvTGlCx3DmXa2o4hr0NtumFpffZ39k6s=
github.com/tiger1103/gfast-token v1.0.10/go.mod h1:a/21mxmj7zFeNvjhZSC0XpEAFHfb1aT2k6DXnufFU1s=
github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM=
@@ -344,8 +342,8 @@ golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvx
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -361,8 +359,8 @@ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwY
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210410081132-afb366fc7cd1/go.mod h1:9tjilg8BloeKEkVJvy7fQ90B1CfIiPueXVOjqfkSzI8=
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -371,8 +369,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -397,15 +395,15 @@ golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
@@ -415,8 +413,8 @@ golang.org/x/tools v0.0.0-20190907020128-2ca718005c18/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -426,6 +424,8 @@ gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c=
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=

60
main.go
View File

@@ -2,17 +2,19 @@ package main
import (
"context"
"model-gateway/model/dto"
"model-gateway/service/job"
"model-gateway/service/task"
"os"
"os/signal"
"syscall"
"time"
"model-gateway/controller"
"model-gateway/service"
"gitea.com/red-future/common/http"
"gitea.com/red-future/common/jaeger"
_ "gitea.com/red-future/common/swagger"
"gitea.redpowerfuture.com/red-future/common/http"
"gitea.redpowerfuture.com/red-future/common/jaeger"
_ "gitea.redpowerfuture.com/red-future/common/swagger"
_ "github.com/gogf/gf/contrib/drivers/pgsql/v2"
_ "github.com/gogf/gf/contrib/nosql/redis/v2"
"github.com/gogf/gf/v2/frame/g"
@@ -33,42 +35,18 @@ func main() {
// 本地调试:可选自动触发 worker/cleaner由配置文件控制
startAutoRunner(ctx)
// 监听退出信号,确保 Ctrl+C 能完整退出(停止 worker/cleaner 并关闭 http server
// 监听退出信号,确保 Ctrl+C 能完整退出(停止 worker/cleaner 并关闭 gateway server
quit := make(chan os.Signal, 1)
signal.Notify(quit, os.Interrupt, syscall.SIGTERM)
<-quit
g.Log().Infof(ctx, "[main] 收到退出信号,开始优雅退出...")
cancel()
// 关闭 http serverRouteRegister 内部是 go Httpserver.Run() 启动的)
// 关闭 gateway serverRouteRegister 内部是 go Httpserver.Run() 启动的)
_ = http.Httpserver.Shutdown()
}
func startAutoRunner(ctx context.Context) {
// worker
if g.Cfg().MustGet(ctx, "asynch.worker.enabled").Bool() {
interval := g.Cfg().MustGet(ctx, "asynch.worker.intervalSeconds").Int()
if interval <= 0 {
interval = 5
}
batchSize := g.Cfg().MustGet(ctx, "asynch.worker.batchSize").Int()
goroutines := g.Cfg().MustGet(ctx, "asynch.worker.goroutines").Int()
ticker := time.NewTicker(time.Duration(interval) * time.Second)
go func() {
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if _, err := service.AsyncWorker.RunOnce(ctx, batchSize, goroutines); err != nil {
g.Log().Warningf(ctx, "[auto-worker] run once failed: %v", err)
}
}
}
}()
}
// cleaner
if g.Cfg().MustGet(ctx, "asynch.cleaner.enabled").Bool() {
interval := g.Cfg().MustGet(ctx, "asynch.cleaner.intervalSeconds").Int()
@@ -83,7 +61,27 @@ func startAutoRunner(ctx context.Context) {
case <-ctx.Done():
return
case <-ticker.C:
service.Cleaner.RunOnce(ctx)
_, _ = job.Cleaner.RunOnce(ctx)
}
}
}()
}
// queryPending
if g.Cfg().MustGet(ctx, "asynch.queryPending.enabled").Bool() {
interval := g.Cfg().MustGet(ctx, "asynch.queryPending.intervalSeconds", 10).Int()
limit := g.Cfg().MustGet(ctx, "asynch.queryPending.limit", 10).Int()
ticker := time.NewTicker(time.Duration(interval) * time.Second)
go func() {
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if _, err := task.Task.QueryPendingTasks(ctx, &dto.QueryPendingTasksReq{Limit: limit}); err != nil {
g.Log().Warningf(ctx, "[auto-queryPending] run once failed: %v", err)
}
}
}
}()

View File

@@ -1,36 +1,44 @@
package dto
import (
"gitea.com/red-future/common/beans"
"model-gateway/model/entity"
"gitea.redpowerfuture.com/red-future/common/beans"
"github.com/gogf/gf/v2/frame/g"
)
// CreateModelReq 添加模型配置
type CreateModelReq struct {
g.Meta `path:"/createModel" method:"post" tags:"模型管理" summary:"创建模型配置" dc:"添加新的模型配置"`
ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"模型名称(唯一标识)"`
ModelType int `p:"modelType" json:"modelType" v:"required#modelType不能为空" dc:"模型类型1-文本生成 2-图像生成 3-语音 4-视频 5-多模态"`
BaseURL string `p:"baseUrl" json:"baseUrl" v:"required#baseUrl不能为空" dc:"模型服务基础地址(如 http(s)://host:port"`
HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式GET/POST默认POST"`
HeadMsg string `p:"headMsg" json:"headMsg" dc:"请求头绑定支持多个逗号分隔示例Authorization:Bearer xxx,Content-Type:application/json"`
IsPrivate *int `p:"isPrivate" json:"isPrivate" v:"in:0,1#私有化参数只能为0或1" dc:"是否私有化0-私有(默认) 1-公共"`
Enabled *int `p:"enabled" json:"enabled" v:"in:0,1#启用参数只能为0或1" dc:"是否启用0-禁用,1-启用默认1"`
IsChatModel *int `p:"isChatModel" json:"isChatModel" v:"in:0,1#对话模型参数只能为0或1" dc:"是否为对话模型0-否1-是默认0"`
IsOwner *int `p:"isOwner" json:"isOwner" v:"in:0,1#是否为所有者参数只能为0或1" dc:"是否为所有者0-否1-是默认0"`
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥,用于模型认证"`
Form any `p:"form" json:"form" dc:"动态表单配置JSON用于前端渲染配置项"`
RequestMapping any `p:"requestMapping" json:"requestMapping" dc:"请求映射"`
ResponseMapping any `p:"responseMapping" json:"responseMapping" dc:"返回映射"`
ResponseBody any `p:"responseBody" json:"responseBody" dc:"返回主体"`
TokenMapping string `p:"tokenMapping" json:"tokenMapping" dc:"token映射"`
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数默认10"`
QueueLimit int `p:"queueLimit" json:"queueLimit" dc:"排队队列上限默认1000"`
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间默认600"`
ExpectedSeconds int `p:"expectedSeconds" json:"expectedSeconds" dc:"模型预计执行时间默认600"`
RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数默认3"`
RetryQueueMaxSeconds int `p:"retryQueueMaxSeconds" json:"retryQueueMaxSeconds" dc:"失败重试最大排队时间默认600"`
AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"任务完成后自动清理时间默认86400"`
Remark string `p:"remark" json:"remark" dc:"备注说明"`
g.Meta `path:"/createModel" method:"post" tags:"模型管理" summary:"创建模型配置" dc:"添加新的模型配置"`
ModelName string `p:"modelName" json:"modelName" v:"required#模型名称不能为空" dc:"模型名称(唯一标识)"`
ModelType int `p:"modelType" json:"modelType" v:"required#模型类型不能为空" dc:"模型类型"`
BaseURL string `p:"baseUrl" json:"baseUrl" v:"required#模型地址不能为空" dc:"模型服务地址"`
HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式GET/POST默认POST"`
HeadMsg map[string]any `p:"headMsg" json:"headMsg" dc:"请求头JSON结构"`
IsPrivate *int `p:"isPrivate" json:"isPrivate" dc:"是否私有化0-私有 1-公共"`
Enabled *int `p:"enabled" json:"enabled" dc:"是否启用0-停用 1-启用"`
IsChatModel *int `p:"isChatModel" json:"isChatModel" dc:"是否为对话模型0-否 1-是"`
CallModel *int `p:"callModel" json:"callModel" dc:"调用模式0-同步 1-异步 2-流式"`
RequiredFields []string `p:"requiredFields" json:"requiredFields" dc:"必填字段"`
IsOwner *int `p:"isOwner" json:"isOwner" dc:"是否为所有者0-否 1-是"`
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥"`
Form []map[string]any `p:"form" json:"form" dc:"动态表单配置"`
RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求映射"`
ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回映射"`
ResponseBody string `p:"responseBody" json:"responseBody" dc:"返回主体"`
ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"`
OperatorName string `p:"operatorName" json:"operatorName" dc:"运营商名称"`
TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"`
ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"`
QueryConfig map[string]any `p:"queryConfig" json:"queryConfig" dc:"查询/回调配置"`
StreamConfig map[string]any `p:"streamConfig" json:"streamConfig" dc:"流式输出配置"`
FirstFrame string `p:"firstFrame" json:"firstFrame" dc:"首帧图片参数"`
LastFrame string `p:"lastFrame" json:"lastFrame" dc:"尾帧图片参数"`
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数默认10"`
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间默认600"`
RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数默认3"`
AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"任务完成后自动清理时间默认86400"`
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址"`
}
type CreateModelRes struct {
@@ -38,48 +46,64 @@ type CreateModelRes struct {
}
type UpdateModelReq struct {
g.Meta `path:"/updateModel" method:"put" tags:"模型管理" summary:"更新模型配置" dc:"更新指定ID的模型配置"`
ID int64 `p:"id" json:"id" v:"required#id不能为空" dc:"配置ID"`
ModelName string `p:"modelName" json:"modelName" dc:"模型名称(唯一标识)"`
ModelType int `p:"modelType" json:"modelType" dc:"模型类型ID列表逗号分隔可选更新"`
BaseURL string `p:"baseUrl" json:"baseUrl" dc:"模型服务基础地址"`
HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式GET/POST(可选更新)"`
HeadMsg string `p:"headMsg" json:"headMsg" dc:"请求头绑定(可选更新)"`
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥,用于模型认证(可选更新)"`
Form any `p:"form" json:"form" dc:"动态表单配置JSON可选更新"`
RequestMapping any `p:"requestMapping" json:"requestMapping" dc:"请求参数映射(可选更新)"`
ResponseMapping any `p:"responseMapping" json:"responseMapping" dc:"返回参数映射(可选更新)"`
ResponseBody any `p:"responseBody" json:"responseBody" dc:"返回主体(可选更新)"`
TokenMapping string `p:"tokenMapping" json:"tokenMapping" dc:"token映射可选更新"`
Enabled *int `p:"enabled" json:"enabled" dc:"是否启用0-禁用1-启用(可选更新)"`
IsPrivate *int `p:"isPrivate" json:"isPrivate" v:"in:0,1#私有化参数只能为0或1" dc:"是否私有化0-私有(默认) 1-公共"`
IsChatModel *int `p:"isChatModel" json:"isChatModel" v:"in:0,1#对话模型参数只能为0或1" dc:"是否为对话模型0-否1-是默认0"`
IsOwner *int `p:"isOwner" json:"isOwner" v:"in:0,1#是否为所有者参数只能为0或1" dc:"是否为所有者0-否1-是默认0"`
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数(可选更新)"`
QueueLimit int `p:"queueLimit" json:"queueLimit" dc:"排队队列上限(可选更新)"`
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒)(可选更新)"`
ExpectedSeconds int `p:"expectedSeconds" json:"expectedSeconds" dc:"模型预计执行时间(秒)(可选更新)"`
RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数(可选更新)"`
RetryQueueMaxSeconds int `p:"retryQueueMaxSeconds" json:"retryQueueMaxSeconds" dc:"失败重试最大排队时间(秒)(可选更新)"`
AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"自动清理间隔(秒)(可选更新)"`
Remark string `p:"remark" json:"remark" dc:"备注说明(可选更新)"`
g.Meta `path:"/updateModel" method:"put" tags:"模型管理" summary:"更新模型配置" dc:"更新指定ID的模型配置"`
ID int64 `p:"id" json:"id" v:"required#id不能为空" dc:"配置ID"`
ModelName string `p:"modelName" json:"modelName" dc:"模型名称"`
ModelType int `p:"modelType" json:"modelType" dc:"模型类型"`
BaseURL string `p:"baseUrl" json:"baseUrl" dc:"模型服务地址"`
HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式GET/POST"`
HeadMsg map[string]any `p:"headMsg" json:"headMsg" dc:"请求头JSON结构"`
IsPrivate *int `p:"isPrivate" json:"isPrivate" dc:"是否私有化0-私有 1-公共"`
Enabled *int `p:"enabled" json:"enabled" dc:"是否启用0-停用 1-启用"`
IsChatModel *int `p:"isChatModel" json:"isChatModel" dc:"是否为对话模型0-否 1-是"`
CallModel *int `p:"callModel" json:"callModel" dc:"调用模式0-同步 1-异步 2-流式"`
RequiredFields []string `p:"requiredFields" json:"requiredFields" dc:"必填字段"`
IsOwner *int `p:"isOwner" json:"isOwner" dc:"是否为所有者0-否 1-是"`
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥"`
Form []map[string]any `p:"form" json:"form" dc:"动态表单配置"`
RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求映射"`
ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回映射"`
ResponseBody string `p:"responseBody" json:"responseBody" dc:"返回主体"`
ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"`
OperatorName string `p:"operatorName" json:"operatorName" dc:"运营商名称"`
TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"`
ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"`
QueryConfig map[string]any `p:"queryConfig" json:"queryConfig" dc:"查询/回调配置"`
StreamConfig map[string]any `p:"streamConfig" json:"streamConfig" dc:"流式输出配置"`
FirstFrame string `p:"firstFrame" json:"firstFrame" dc:"首帧图片参数"`
LastFrame string `p:"lastFrame" json:"lastFrame" dc:"尾帧图片参数"`
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数"`
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒)"`
RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数"`
AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"任务完成后自动清理时间(秒)"`
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址"`
}
type UpdateModelRes struct {
ID int64 `json:"id,string" dc:"配置ID"`
}
// DeleteModelReq 删除模型配置
type DeleteModelReq struct {
g.Meta `path:"/deleteModel" method:"delete" tags:"模型管理" summary:"删除模型配置" dc:"删除指定ID的模型配置"`
ID string `p:"id" json:"id,string" v:"required#id不能为空" dc:"配置ID"`
ID int64 `p:"id" json:"id,string" v:"required#id不能为空" dc:"配置ID"`
}
type DeleteModelRes struct {
ID int64 `json:"id,string" dc:"配置ID"`
}
// GetModelReq 获取模型配置详情
type GetModelReq struct {
g.Meta `path:"/getModel" method:"get" tags:"模型管理" summary:"获取模型配置" dc:"根据模型ID获取配置详情"`
ID int64 `p:"id" json:"id,string" v:"required#id不能为空" dc:"配置ID"`
Creator string `p:"creator" json:"creator" dc:"创建人"`
g.Meta `path:"/getModel" method:"get" tags:"模型管理" summary:"获取模型配置" dc:"根据模型ID获取配置详情"`
ID int64 `p:"id" json:"id,string" dc:"配置ID"`
Creator string `p:"creator" json:"creator" dc:"创建人"`
IsChatModel *int `p:"isChatModel" json:"isChatModel" dc:"是否为聊天模型"`
ModelName string `p:"modelName" json:"modelName" dc:"模型名称(唯一标识)"`
}
type GetModelRes struct {
Model any `json:"model" dc:"模型配置详情"`
Model *entity.AsynchModel `json:"model" dc:"模型配置详情"`
}
// ListModelReq 配置列表
@@ -124,11 +148,43 @@ type TypeItem struct {
Type map[int]string `json:"type" dc:"模型类型ID到名称的映射"`
}
type ListOperatorReq struct {
g.Meta `path:"/listOperator" method:"get" tags:"模型管理" summary:"获取运营商列表" dc:"获取运营商列表"`
}
type ListOperatorRes struct {
List []string `json:"list" dc:"运营商名称到ID的映射"`
}
type UpdateChatModelReq struct {
g.Meta `path:"/updateChatModel" method:"post" tags:"模型管理" summary:"更新聊天模型" dc:"更新指定模型的聊天模型"`
Id int64 `p:"id" json:"id" v:"required#model不能为空" dc:"模型id"`
}
type UpdateChatModelRes struct {
ID int64 `json:"id,string" dc:"模型ID"`
}
type GetIsChatModelReq struct {
g.Meta `path:"/getIsChatModel" method:"get" tags:"模型管理" summary:"获取模型是否为聊天模型" dc:"根据模型ID获取是否为聊天模型"`
}
type GetIsChatModelRes struct {
Model any `json:"model" dc:"模型详情"`
}
// NodeFormField 节点表单
type NodeFormField struct {
Value any `json:"value" dc:"字段值"`
Field string `json:"field" dc:"字段标识"`
Label string `json:"label" dc:"字段标签"`
Type string `json:"type" dc:"字段类型"`
Required bool `json:"required" dc:"是否必填"`
Default any `json:"default,omitempty" dc:"默认值"`
Options []SelectOption `json:"options" dc:"下拉选项列表"`
FieldConstraint any `json:"fieldConstraint" dc:"字段约束"`
}
type SelectOption struct {
Label string `json:"label" dc:"选项标签"`
Value string `json:"value" dc:"选项值"`
}

View File

@@ -5,22 +5,55 @@ import "github.com/gogf/gf/v2/frame/g"
// CreateTaskReq 创建异步任务
type CreateTaskReq struct {
g.Meta `path:"/createTask" method:"post" tags:"任务管理" summary:"创建异步任务" dc:"创建异步任务并返回任务ID创建成功后会立即异步尝试执行当前任务执行成功后按回调配置触发钩子"`
ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"模型名称"`
BizName string `p:"bizName" json:"bizName" dc:"业务名称(调用方模块/系统,用于统计)"`
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址(可选,用于后续业务通知)"`
InputRef string `p:"inputRef" json:"inputRef" dc:"输入引用如OSS/文件引用等)"`
RequestPayload any `p:"requestPayload" json:"requestPayload" dc:"请求负载(透传给模型服务)"`
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"模型名称"`
BizName string `p:"bizName" json:"bizName" dc:"业务名称(调用方模块/系统,用于统计)"`
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址(可选,用于后续业务通知)"`
InputRef string `p:"inputRef" json:"inputRef" dc:"输入引用如OSS/文件引用等)"`
RequestPayload map[string]any `p:"requestPayload" json:"requestPayload" dc:"请求负载(透传给模型服务)"`
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
BuildType int64 `json:"buildType" dc:"构建类型1-提示词构建 2-节点构建"`
}
type CreateTaskRes struct {
TaskID string `json:"taskId" dc:"任务ID"`
}
type ModelTaskCallbackReq struct {
g.Meta `path:"/modelCallback" method:"post" tags:"异步任务" summary:"模型任务回调通知"`
TaskID string `json:"id" dc:"任务ID"`
Status string `json:"status" dc:"queued/running/succeeded/failed/expired"`
Content map[string]any `json:"content,omitempty" dc:"任务结果内容"`
Usage map[string]any `json:"usage,omitempty" dc:"token用量"`
}
type ModelTaskCallbackRes struct {
Success bool `json:"success" dc:"是否接收成功"`
}
// QueryPendingTasksReq 批量轮询请求
type QueryPendingTasksReq struct {
g.Meta `path:"/queryPending" method:"get" tags:"异步任务" summary:"批量轮询进行中的任务"`
Limit int `p:"limit" json:"limit" dc:"查询数量默认10"`
}
// QueryPendingTasksRes 批量轮询响应
type QueryPendingTasksRes struct {
Total int `json:"total" dc:"本次查询数量"`
Results []QueryTaskItem `json:"results" dc:"查询结果列表"`
}
// QueryTaskItem 单个任务查询结果
type QueryTaskItem struct {
TaskID string `json:"taskId" dc:"任务ID"`
Status string `json:"status" dc:"任务状态"`
Content map[string]any `json:"content,omitempty" dc:"结果内容"`
Usage map[string]any `json:"usage,omitempty" dc:"token用量"`
}
// GetTaskResultReq 获取结果(只返回 oss 地址)
type GetTaskResultReq struct {
g.Meta `path:"/getTaskResult" method:"get" tags:"任务管理" summary:"获取任务结果" dc:"根据任务ID获取结果只返回OSS地址"`
TaskID string `p:"taskId" json:"taskId" v:"required#taskId不能为空" dc:"任务ID"`
TaskID string `p:"taskId" json:"taskId" v:"required#taskwId不能为空" dc:"任务ID"`
}
type GetTaskResultRes struct {

View File

@@ -1,88 +1,103 @@
package entity
import "gitea.com/red-future/common/beans"
import "gitea.redpowerfuture.com/red-future/common/beans"
type asynchModelCol struct {
beans.SQLBaseCol
ModelName string
ModelType string
BaseURL string
HttpMethod string
HeadMsg string
FormJSON string
RequestMapping string
ResponseMapping string
ResponseBody string
TokenMapping string
Prompt string
IsPrivate string
IsChatModel string
ApiKey string
Enabled string
MaxConcurrency string
QueueLimit string
TimeoutSeconds string
ExpectedSeconds string
RetryTimes string
RetryQueueMaxSecs string
AutoCleanSeconds string
Remark string
IsOwner string
ModelName string
ModelType string
BaseURL string
HttpMethod string
HeadMsg string
FormJSON string
RequestMapping string
ResponseMapping string
ResponseBody string
ResponseTokenField string
RequiredFields string
IsPrivate string
IsChatModel string
CallMode string
ApiKey string
Enabled string
MaxConcurrency string
TimeoutSeconds string
RetryTimes string
AutoCleanSeconds string
IsOwner string
OperatorName string
TokenConfig string
ExtendMapping string
QueryConfig string
StreamConfig string
FirstFrame string
LastFrame string
CallbackUrl string
}
var AsynchModelCol = asynchModelCol{
SQLBaseCol: beans.DefSQLBaseCol,
ModelName: "model_name",
ModelType: "model_type",
BaseURL: "base_url",
HttpMethod: "http_method",
HeadMsg: "head_msg",
FormJSON: "form_json",
RequestMapping: "request_mapping",
ResponseMapping: "response_mapping",
ResponseBody: "response_body",
TokenMapping: "token_mapping",
Prompt: "prompt",
IsPrivate: "is_private",
IsChatModel: "is_chat_model",
ApiKey: "api_key",
Enabled: "enabled",
MaxConcurrency: "max_concurrency",
QueueLimit: "queue_limit",
TimeoutSeconds: "timeout_seconds",
ExpectedSeconds: "expected_seconds",
RetryTimes: "retry_times",
RetryQueueMaxSecs: "retry_queue_max_seconds",
AutoCleanSeconds: "auto_clean_seconds",
Remark: "remark",
IsOwner: "is_owner",
SQLBaseCol: beans.DefSQLBaseCol,
ModelName: "model_name",
ModelType: "model_type",
BaseURL: "base_url",
HttpMethod: "http_method",
HeadMsg: "head_msg",
FormJSON: "form_json",
RequestMapping: "request_mapping",
ResponseMapping: "response_mapping",
ResponseBody: "response_body",
ResponseTokenField: "response_token_field",
RequiredFields: "required_fields",
IsPrivate: "is_private",
IsChatModel: "is_chat_model",
CallMode: "call_mode",
ApiKey: "api_key",
Enabled: "enabled",
MaxConcurrency: "max_concurrency",
TimeoutSeconds: "timeout_seconds",
RetryTimes: "retry_times",
AutoCleanSeconds: "auto_clean_seconds",
IsOwner: "is_owner",
OperatorName: "operator_name",
TokenConfig: "token_config",
ExtendMapping: "extend_mapping",
QueryConfig: "query_config",
StreamConfig: "stream_config",
FirstFrame: "first_frame",
LastFrame: "last_frame",
CallbackUrl: "callback_url",
}
// AsynchModel 异步模型配置
type AsynchModel struct {
beans.SQLBaseDO `orm:",inline"`
ModelName string `orm:"model_name" json:"modelName"`
ModelType int `orm:"model_type" json:"modelType"`
BaseURL string `orm:"base_url" json:"baseUrl"`
HttpMethod string `orm:"http_method" json:"httpMethod"`
HeadMsg string `orm:"head_msg" json:"headMsg"`
Form any `orm:"form_json" json:"form"`
RequestMapping any `orm:"request_mapping" json:"requestMapping"`
ResponseMapping any `orm:"response_mapping" json:"responseMapping"`
ResponseBody any `orm:"response_body" json:"responseBody"`
TokenMapping string `orm:"token_mapping" json:"tokenMapping"`
Prompt string `orm:"prompt" json:"prompt"`
IsPrivate *int `orm:"is_private" json:"isPrivate"`
IsChatModel *int `orm:"is_chat_model" json:"isChatModel"`
ApiKey string `orm:"api_key" json:"apiKey"`
Enabled *int `orm:"enabled" json:"enabled"`
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
QueueLimit int `orm:"queue_limit" json:"queueLimit"`
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
ExpectedSeconds int `orm:"expected_seconds" json:"expectedSeconds"`
RetryTimes int `orm:"retry_times" json:"retryTimes"`
RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"retryQueueMaxSeconds"`
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
Remark string `orm:"remark" json:"remark"`
IsOwner *int `json:"isOwner" orm:"is_owner"` // 1=当前用户创建的0=超级管理员的
beans.SQLBaseDO `orm:",inline"`
ModelName string `orm:"model_name" json:"modelName"`
ModelType int `orm:"model_type" json:"modelType"`
BaseURL string `orm:"base_url" json:"baseUrl"`
HttpMethod string `orm:"http_method" json:"httpMethod"`
HeadMsg map[string]any `orm:"head_msg" json:"headMsg"`
Form []map[string]any `orm:"form_json" json:"form"`
RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"`
ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"`
ResponseBody string `orm:"response_body" json:"responseBody"`
ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"`
RequiredFields []string `orm:"required_fields" json:"requiredFields"`
IsPrivate *int `orm:"is_private" json:"isPrivate"`
IsChatModel *int `orm:"is_chat_model" json:"isChatModel"`
CallMode *int `orm:"call_mode" json:"callMode"`
ApiKey string `orm:"api_key" json:"apiKey"`
Enabled *int `orm:"enabled" json:"enabled"`
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
RetryTimes int `orm:"retry_times" json:"retryTimes"`
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
IsOwner *int `json:"isOwner" orm:"is_owner"`
OperatorName string `orm:"operator_name" json:"operatorName"`
TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"`
ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"`
QueryConfig map[string]any `orm:"query_config" json:"queryConfig"`
StreamConfig map[string]any `orm:"stream_config" json:"streamConfig"`
FirstFrame string `orm:"first_frame" json:"firstFrame"`
LastFrame string `orm:"last_frame" json:"lastFrame"`
CallbackUrl string `orm:"callback_url" json:"callbackUrl"`
}

View File

@@ -1,26 +0,0 @@
package entity
import "gitea.com/red-future/common/beans"
type asynchModelTypeCol struct {
beans.SQLBaseCol
TypeID string
TypeName string
Remark string
}
var AsynchModelTypeCol = asynchModelTypeCol{
SQLBaseCol: beans.DefSQLBaseCol,
TypeID: "type_id",
TypeName: "type_name",
Remark: "remark",
}
// AsynchModelType 模型类型(图片/音频/视频等)
type AsynchModelType struct {
beans.SQLBaseDO `orm:",inline"`
TypeID int `orm:"type_id" json:"typeId"`
TypeName string `orm:"type_name" json:"type"`
Remark string `orm:"remark" json:"remark"`
}

View File

@@ -1,7 +1,7 @@
package entity
import (
"gitea.com/red-future/common/beans"
"gitea.redpowerfuture.com/red-future/common/beans"
"github.com/gogf/gf/v2/os/gtime"
)
@@ -62,28 +62,28 @@ var AsynchTaskCol = asynchTaskCol{
// AsynchTask 异步任务
type AsynchTask struct {
beans.SQLBaseDO `orm:",inline"`
ModelName string `orm:"model_name" json:"modelName"`
TaskID string `orm:"task_id" json:"taskId"`
BizName string `orm:"biz_name" json:"bizName"`
CallbackURL string `orm:"callback_url" json:"callbackUrl"`
ModelKey string `orm:"model_key" json:"modelKey"`
State int `orm:"state" json:"state"` // 0排队中/1执行中/2成功/3失败/4已下载
OssFile string `orm:"oss_file" json:"ossFile"`
FileType string `orm:"file_type" json:"fileType"`
FileSize int64 `orm:"file_size" json:"fileSize"`
ErrorMsg string `orm:"error_msg" json:"errorMsg"`
StartedAt *gtime.Time `orm:"started_at" json:"startedAt"`
FinishedAt *gtime.Time `orm:"finished_at" json:"finishedAt"`
DurationSeconds int64 `orm:"duration_seconds" json:"durationSeconds"`
ExpireAt *gtime.Time `orm:"expire_at" json:"expireAt"` // 已下载(state=4)后的过期时间
RetryCount int `orm:"retry_count" json:"retryCount"`
EnqueueAt *gtime.Time `orm:"enqueue_at" json:"enqueueAt"`
Phase int `orm:"phase" json:"phase"` // 0模型阶段/1OSS阶段
TmpFile string `orm:"tmp_file" json:"tmpFile"` // 临时结果文件路径
InputRef string `orm:"input_ref" json:"inputRef"`
RequestPayload any `orm:"request_payload" json:"requestPayload"`
TextResult string `orm:"text_result" json:"text"`
EpicycleId int64 `orm:"epicycle_id" json:"epicycleId"` // 轮次ID用于标识同一轮次的任务
ExpendTokens int64 `orm:"expend_tokens" json:"expendTokens"` // 消耗 token 数
RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"-"`
ModelName string `orm:"model_name" json:"modelName"`
TaskID string `orm:"task_id" json:"taskId"`
BizName string `orm:"biz_name" json:"bizName"`
CallbackURL string `orm:"callback_url" json:"callbackUrl"`
ModelKey string `orm:"model_key" json:"modelKey"`
State int `orm:"state" json:"state"` // 0排队中/1执行中/2成功/3失败/4已下载
OssFile string `orm:"oss_file" json:"ossFile"`
FileType string `orm:"file_type" json:"fileType"`
FileSize int64 `orm:"file_size" json:"fileSize"`
ErrorMsg string `orm:"error_msg" json:"errorMsg"`
StartedAt *gtime.Time `orm:"started_at" json:"startedAt"`
FinishedAt *gtime.Time `orm:"finished_at" json:"finishedAt"`
DurationSeconds int64 `orm:"duration_seconds" json:"durationSeconds"`
ExpireAt *gtime.Time `orm:"expire_at" json:"expireAt"` // 已下载(state=4)后的过期时间
RetryCount int `orm:"retry_count" json:"retryCount"`
EnqueueAt *gtime.Time `orm:"enqueue_at" json:"enqueueAt"`
Phase int `orm:"phase" json:"phase"` // 0模型阶段/1OSS阶段
TmpFile string `orm:"tmp_file" json:"tmpFile"` // 临时结果文件路径
InputRef string `orm:"input_ref" json:"inputRef"`
RequestPayload map[string]any `orm:"request_payload" json:"requestPayload"`
TextResult map[string]any `orm:"text_result" json:"text"`
EpicycleId int64 `orm:"epicycle_id" json:"epicycleId"` // 轮次ID用于标识同一轮次的任务
ExpendTokens int64 `orm:"expend_tokens" json:"expendTokens"` // 消耗 token 数
RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"-"`
}

View File

@@ -1,7 +1,7 @@
package entity
import (
"gitea.com/red-future/common/beans"
"gitea.redpowerfuture.com/red-future/common/beans"
)
type LogsModelPpCol struct {

View File

@@ -1,67 +0,0 @@
package service
import (
"context"
"encoding/json"
"model-gateway/model/entity"
"gitea.com/red-future/common/http"
"github.com/gogf/gf/v2/frame/g"
)
// triggerCallback 任务成功后的回调:
// - JSON body 参数task_id/state/oss_file/file_type/text可选
func triggerCallback(ctx context.Context, t *entity.AsynchTask) {
callbackURL := t.BizName + t.CallbackURL
headers := forwardHeaders(ctx)
var req struct{}
payload := map[string]interface{}{
"task_id": t.TaskID,
"state": t.State,
"oss_file": t.OssFile,
"file_type": t.FileType,
"text": t.TextResult,
"error_msg": t.ErrorMsg,
}
jsonData, err := json.Marshal(payload)
if err != nil {
g.Log().Warningf(ctx, "[回调] JSON序列化失败 taskId=%s 错误=%v", t.TaskID, err)
return
}
g.Log().Infof(ctx, "[回调] 开始发送 taskId=%s 回调地址=%s 请求头数量=%d 消息体大小=%d字节",
t.TaskID, callbackURL, len(headers), len(jsonData))
err = http.Post(ctx, callbackURL, headers, &req, jsonData)
if err != nil {
g.Log().Warningf(ctx, "[回调] 发送失败 taskId=%s 回调地址=%s 错误=%v", t.TaskID, callbackURL, err)
return
}
g.Log().Infof(ctx, "[回调] 发送成功 taskId=%s 回调地址=%s 消息体大小=%d字节", t.TaskID, callbackURL, len(jsonData))
}
// triggerPromptsCallback 任务成功后的提示词回调
// - JSON body 参数epicycleId轮次id/textResult模型回答消息
func triggerPromptsCallback(ctx context.Context, t *entity.AsynchTask, epicycleId int64) {
callbackURL := "prompts-core/session/sessionCallback"
headers := forwardHeaders(ctx)
var req struct{}
payload := map[string]interface{}{
"epicycleId": epicycleId,
"text": t.TextResult,
}
jsonData, err := json.Marshal(payload)
if err != nil {
g.Log().Warningf(ctx, "[提示词回调] JSON序列化失败 epicycleId=%d 错误=%v", epicycleId, err)
return
}
g.Log().Infof(ctx, "[提示词回调] 开始发送 epicycleId=%d 回调地址=%s 请求头数量=%d 消息体大小=%d字节",
t.EpicycleId, callbackURL, len(headers), len(jsonData))
err = http.Post(ctx, callbackURL, headers, &req, jsonData)
if err != nil {
g.Log().Warningf(ctx, "[提示词回调] 发送失败 epicycleId=%d 回调地址=%s 错误=%v", t.EpicycleId, callbackURL, err)
return
}
g.Log().Infof(ctx, "[提示词回调] 发送成功 epicycleId=%d 回调地址=%s 消息体大小=%d字节", t.EpicycleId, callbackURL, len(jsonData))
}

View File

@@ -1,47 +0,0 @@
package service
import (
"net/http"
"strings"
)
// DetectFileType 根据返回的二进制内容推断 contentType + 扩展名(尽量稳定)
func DetectFileType(data []byte) (contentType string, ext string) {
if len(data) == 0 {
return "application/octet-stream", ""
}
ct := http.DetectContentType(data)
// http.DetectContentType 可能带 charset 等参数text/plain; charset=utf-8
if idx := strings.Index(ct, ";"); idx > 0 {
ct = strings.TrimSpace(ct[:idx])
}
switch ct {
case "audio/mpeg":
return ct, ".mp3"
case "audio/wave", "audio/wav", "audio/x-wav":
return ct, ".wav"
case "video/mp4":
return ct, ".mp4"
case "image/png":
return ct, ".png"
case "image/jpeg":
return ct, ".jpg"
case "application/pdf":
return ct, ".pdf"
case "text/plain":
return ct, ".txt"
case "application/json":
return ct, ".json"
default:
// 兜底:尝试从 ct 截取 subtype 作为后缀(例如 application/json
if parts := strings.Split(ct, "/"); len(parts) == 2 {
sub := parts[1]
// 避免出现 "plain; charset=utf-8" 之类的后缀
if idx := strings.Index(sub, ";"); idx > 0 {
sub = strings.TrimSpace(sub[:idx])
}
return ct, "." + sub
}
return ct, ""
}
}

View File

@@ -0,0 +1,195 @@
package gateway
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"mime/multipart"
"model-gateway/common/util"
"model-gateway/model/entity"
"time"
commonHttp "gitea.redpowerfuture.com/red-future/common/http"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/guid"
)
type UploadFileResponse struct {
FileURL string `json:"fileURL"` // 文件 URL
FileSize int `json:"fileSize"` // 文件大小(字节)
FileName string `json:"fileName"` // 文件名
FileFormat string `json:"fileFormat"` // 文件格式
FileAddressPrefix string `json:"fileAddressPrefix"` // 文件地址前缀
}
func UploadByTask(ctx context.Context, data []byte, fileExt string) (oss *UploadFileResponse, err error) {
// multipart
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
ext := fileExt
if ext == "" {
ext = ".bin"
}
if ext[0] != '.' {
ext = "." + ext
}
filename := fmt.Sprintf("asynch_%d_%s%s", time.Now().Unix(), guid.S(), ext)
part, err := writer.CreateFormFile("file", filename)
if err != nil {
return nil, err
}
if _, err := part.Write(data); err != nil {
return nil, err
}
contentType := writer.FormDataContentType()
if err = writer.Close(); err != nil {
return nil, err
}
headers := util.ForwardHeaders(ctx)
headers["Content-Type"] = contentType
fullURL := "oss/file/uploadFile"
g.Log().Infof(ctx, "[OSS] upload start url=%s filename=%s size=%d", fullURL, filename, len(data))
var resp UploadFileResponse
if err = commonHttp.Post(ctx, fullURL, headers, &resp, body.Bytes()); err != nil {
return nil, err
}
if &resp == nil {
return nil, errors.New("[OSS] 上传文件失败")
}
g.Log().Infof(ctx, "[OSS] 上传成功 url=%s size=%d format=%s", resp.FileURL, resp.FileSize, resp.FileFormat)
return &resp, nil
}
// CallbackPayload 回调请求体
type CallbackPayload struct {
TaskId string `json:"task_id"`
State int `json:"state"`
OssFile string `json:"oss_file"`
FileType string `json:"file_type"`
Messages map[string]any `json:"messages"`
ErrorMsg string `json:"error_msg"`
}
// TriggerCallback 任务的回调
func TriggerCallback(ctx context.Context, t *entity.AsynchTask) {
headers := util.ForwardHeaders(ctx)
var resp struct{}
payload := CallbackPayload{
TaskId: t.TaskID,
State: t.State,
OssFile: t.OssFile,
FileType: t.FileType,
Messages: t.TextResult,
ErrorMsg: t.ErrorMsg,
}
jsonData, err := json.Marshal(payload)
if err != nil {
g.Log().Warningf(ctx, "[回调] JSON序列化失败 taskId=%s 错误=%v", t.TaskID, err)
return
}
g.Log().Infof(ctx, "[回调] 开始发送 taskId=%s 回调地址=%s 请求头数量=%d 消息体大小=%d字节",
t.TaskID, t.CallbackURL, len(headers), len(jsonData))
err = commonHttp.Post(ctx, t.CallbackURL, headers, &resp, jsonData)
if err != nil {
g.Log().Warningf(ctx, "[回调] 发送失败 taskId=%s 回调地址=%s 错误=%v", t.TaskID, t.CallbackURL, err)
return
}
g.Log().Infof(ctx, "[回调] 发送成功 taskId=%s 回调地址=%s 消息体大小=%d字节", t.TaskID, t.CallbackURL, len(jsonData))
}
// PromptsCallbackPayload 提示词回调请求体
type PromptsCallbackPayload struct {
EpicycleId int64 `json:"epicycleId"`
Messages map[string]any `json:"messages"`
}
// TriggerPromptsCallback 任务成功后的提示词回调
func TriggerPromptsCallback(ctx context.Context, t *entity.AsynchTask, epicycleId int64) {
callbackURL := "prompts-core/session/callback"
headers := util.ForwardHeaders(ctx)
var resp struct{}
payload := PromptsCallbackPayload{
EpicycleId: epicycleId,
Messages: t.TextResult,
}
jsonData, err := json.Marshal(payload)
if err != nil {
g.Log().Warningf(ctx, "[提示词回调] JSON序列化失败 epicycleId=%d 错误=%v", epicycleId, err)
return
}
g.Log().Infof(ctx, "[提示词回调] 开始发送 epicycleId=%d 回调地址=%s 请求头数量=%d 消息体大小=%d字节",
t.EpicycleId, callbackURL, len(headers), len(jsonData))
err = commonHttp.Post(ctx, callbackURL, headers, &resp, jsonData)
if err != nil {
g.Log().Warningf(ctx, "[提示词回调] 发送失败 epicycleId=%d 回调地址=%s 错误=%v", t.EpicycleId, callbackURL, err)
return
}
g.Log().Infof(ctx, "[提示词回调] 发送成功 epicycleId=%d 回调地址=%s 消息体大小=%d字节", t.EpicycleId, callbackURL, len(jsonData))
}
// IsSuperAdmin 调用admin-go服务检查是否是超级管理员
func IsSuperAdmin(ctx context.Context) (res bool, err error) {
headers := util.ForwardHeaders(ctx)
var r = make(map[string]bool)
if err = commonHttp.Get(ctx, "admin-go/api/v1/system/user/checkIsSuperAdmin", headers, &r); err != nil {
return false, err
}
return r["isSuperAdmin"], err
}
//// callback 向回调地址 POST 任务结果(与查询接口 GetTaskRes 出参一致)
//func (s *audioTaskService) callback(ctx context.Context, taskID, status, errMsg, callbackURL string) {
// if callbackURL == "" {
// return
// }
//
// task, _ := dao.TranscribeTask.GetByTaskID(ctx, taskID)
// if task == nil {
// g.Log().Errorf(ctx, "[回调 %s] 任务不存在", taskID)
// return
// }
//
// detailList, _ := dao.TranscribeTaskDetail.ListByTaskID(ctx, taskID)
// detailItems := make([]dto.TranscribeTaskDetailItem, 0, len(detailList))
// for i := range detailList {
// detailItems = append(detailItems, dao.DetailEntityToItem(&detailList[i]))
// }
//
// // 构建与查询接口一致的 taskInfo
// taskInfo := dao.EntityToItem(task)
//
// // 兼容历史数据: 从 result 中补全 scenes 等字段
// detailItems = enrichDetailsFromResult(task.Result, detailItems)
//
// payload := dto.CallbackPayload{
// TaskInfo: taskInfo,
// DetailList: detailItems,
// }
//
// body, _ := json.Marshal(payload)
//
// // 透传调用方的用户信息
// userJSON, _ := json.Marshal(beans.User{UserName: "admin", TenantId: 1})
//
// req, _ := http.NewRequest("POST", callbackURL, bytes.NewReader(body))
// req.Header.Set("Content-Type", "application/json")
// req.Header.Set("X-User-Info", string(userJSON))
//
// resp, reqErr := http.DefaultClient.Do(req)
// if reqErr != nil {
// g.Log().Errorf(ctx, "[回调 %s] 请求失败: %v", taskID, reqErr)
// return
// }
// defer resp.Body.Close()
//
// respBody, _ := io.ReadAll(resp.Body)
// g.Log().Infof(ctx, "[回调 %s] 响应 status=%d, body=%s", taskID, resp.StatusCode, string(respBody))
//}

View File

@@ -1,53 +0,0 @@
package service
import (
"context"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/frame/g"
)
// asyncCtx 固化异步执行所需的 token/user避免请求结束后丢失仅在“同请求内起 goroutine”有用
// 本项目当前是“落库 + 后台 worker”模式因此还会把必要信息持久化到任务表的 request_payload 中。
func asyncCtx(ctx context.Context) context.Context {
asyncCtx := context.WithoutCancel(ctx)
if r := g.RequestFromCtx(ctx); r != nil {
if token := r.Header.Get("Authorization"); token != "" {
asyncCtx = context.WithValue(asyncCtx, "token", token)
}
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
asyncCtx = context.WithValue(asyncCtx, "xUserInfo", userInfo)
}
}
if user, err := utils.GetUserInfo(ctx); err == nil && user != nil {
asyncCtx = context.WithValue(asyncCtx, "user", user)
}
return asyncCtx
}
// forwardHeaders 透传调用链路中必须的头信息(优先使用 ctx 里固化的 token / xUserInfo
func forwardHeaders(ctx context.Context) map[string]string {
headers := make(map[string]string)
if token, ok := ctx.Value("token").(string); ok && token != "" {
headers["Authorization"] = token
}
if x, ok := ctx.Value("xUserInfo").(string); ok && x != "" {
headers["X-User-Info"] = x
}
// 兜底:从请求头拿
if r := g.RequestFromCtx(ctx); r != nil {
if headers["Authorization"] == "" {
if token := r.Header.Get("Authorization"); token != "" {
headers["Authorization"] = token
}
}
if headers["X-User-Info"] == "" {
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
headers["X-User-Info"] = userInfo
}
}
}
return headers
}

View File

@@ -1,7 +1,10 @@
package service
package job
import (
"context"
"model-gateway/model/dto"
"model-gateway/service/queue"
"os"
"time"
"model-gateway/dao"
@@ -14,35 +17,36 @@ var Cleaner = &cleaner{}
type cleaner struct{}
// RunOnce 由上层定时任务触发:执行一次清理/重试
func (c *cleaner) RunOnce(ctx context.Context) {
func (c *cleaner) RunOnce(ctx context.Context) (res *dto.CleanWorkRes, err error) {
// 1) 清理已下载(state=4)且过期的任务(硬删除 + OSS
expired, err := dao.Task.ListExpiredDownloadedGlobal(ctx, 200)
if err != nil {
g.Log().Errorf(ctx, "[cleaner] list expired(downloaded) error: %v", err)
g.Log().Errorf(ctx, "[清理] 查询已下载过期任务失败: %v", err)
} else {
for _, t := range expired {
deleteTmpResult(t.TmpFile)
_ = os.Remove(t.TmpFile)
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
}
g.Log().Infof(ctx, "[cleaner] expired(downloaded) cleaned, count=%d", len(expired))
g.Log().Infof(ctx, "[清理] 已下载过期任务清理完成, count=%d", len(expired))
}
// 2) 超时任务标失败
list, err := dao.Task.ListTimeoutTasksGlobal(ctx, 200)
if err != nil {
g.Log().Errorf(ctx, "[cleaner] list timeout error: %v", err)
g.Log().Errorf(ctx, "[清理] 查询超时任务失败: %v", err)
} else {
for _, t := range list {
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, "任务超时自动失败")
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
t.ErrorMsg = "任务超时自动失败"
_ = dao.Task.UpdateFailedGlobal(ctx, t)
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
}
g.Log().Infof(ctx, "[cleaner] timeout cleaned, count=%d", len(list))
g.Log().Infof(ctx, "[清理] 超时任务处理完成, count=%d", len(list))
}
// 3) 失败(state=3)的任务按模型配置 retry_times 重新入队(放到队尾)
retryable, err := dao.Task.ListFailedRetryableGlobal(ctx, 200)
if err != nil {
g.Log().Errorf(ctx, "[cleaner] list failed retryable error: %v", err)
g.Log().Errorf(ctx, "[清理] 查询可重试任务失败: %v", err)
} else {
for _, t := range retryable {
// 失败任务重新入队state=3 -> 0先严格占用 queue_limit slot占用失败则留在失败态下一轮再尝试
@@ -51,9 +55,9 @@ func (c *cleaner) RunOnce(ctx context.Context) {
if err != nil || m == nil {
continue
}
limit := GetRuntimeQueueLimit(ctx, t.ModelName, m.QueueLimit)
limit := queue.GetRuntimeQueueLimit(ctx, t.ModelName, m.MaxConcurrency*2)
if limit > 0 {
ok, _ := AcquireQueueSlot(ctx, t.ModelName, t.TaskID, limit, m.ExpectedSeconds)
ok, _ := queue.AcquireQueueSlot(ctx, t.ModelName, t.TaskID, limit, m.TimeoutSeconds)
if !ok {
continue
}
@@ -73,20 +77,23 @@ func (c *cleaner) RunOnce(ctx context.Context) {
}
_ = dao.Task.RequeueForRetryGlobal(ctx, t.Id, enqueueAt)
}
g.Log().Infof(ctx, "[cleaner] failed retryable cleaned, count=%d", len(retryable))
g.Log().Infof(ctx, "[清理] 可重试任务重新入队完成, count=%d", len(retryable))
}
// 4) 超过重试次数仍失败(state=3)的任务:硬删除
exhausted, err := dao.Task.ListFailedExhaustedGlobal(ctx, 200)
if err != nil {
g.Log().Errorf(ctx, "[cleaner] list failed exhausted error: %v", err)
g.Log().Errorf(ctx, "[清理] 查询重试耗尽任务失败: %v", err)
} else {
for _, t := range exhausted {
deleteTmpResult(t.TmpFile)
_ = os.Remove(t.TmpFile)
// 重试耗尽硬删除:释放闸门占位(兜底,若此前已释放则幂等)
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
}
g.Log().Infof(ctx, "[cleaner] failed exhausted cleaned, count=%d", len(exhausted))
g.Log().Infof(ctx, "[清理] 重试耗尽任务清理完成, count=%d", len(exhausted))
}
return &dto.CleanWorkRes{
Ok: true,
}, nil
}

View File

@@ -0,0 +1,256 @@
package model
import (
"context"
"errors"
"model-gateway/common/util"
"model-gateway/consts/public"
"model-gateway/dao"
"model-gateway/model/dto"
"model-gateway/model/entity"
"model-gateway/service/gateway"
"gitea.redpowerfuture.com/red-future/common/beans"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"gitea.redpowerfuture.com/red-future/common/utils"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
var Model = &modelService{}
type modelService struct{}
// Create 创建模型
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (*dto.CreateModelRes, error) {
// 1如果设为会话模型先把该用户旧会话模型取消
if !g.IsEmpty(req.IsChatModel) && *req.IsChatModel == 1 {
if err := s.clearUserChatModel(ctx); err != nil {
return nil, err
}
}
// 2判断是否超管决定 isOwner
req.IsOwner = gconv.PtrInt(1)
if isAdmin, _ := gateway.IsSuperAdmin(ctx); isAdmin {
req.IsOwner = gconv.PtrInt(0)
}
// 3入库
id, err := dao.Model.Insert(ctx, util.ConvertTo[entity.AsynchModel](req))
if err != nil {
return nil, err
}
return &dto.CreateModelRes{ID: id}, nil
}
// Update 更新模型配置
func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) error {
// 1会话模型唯一性校验
if req.IsChatModel != nil && *req.IsChatModel == 1 {
if err := s.checkChatModelUnique(ctx); err != nil {
return err
}
}
// 2超管创建/普通用户更新
req.IsOwner = gconv.PtrInt(1)
if isAdmin, _ := gateway.IsSuperAdmin(ctx); isAdmin {
req.IsOwner = gconv.PtrInt(0)
_, err := dao.Model.Update(ctx, util.ConvertTo[entity.AsynchModel](req))
return err
}
// 3跨租户判断超管的模型不允许直接修改走插入新记录
model, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
})
if err != nil {
return err
}
if model.TenantId == 1 {
_, err = dao.Model.Insert(ctx, util.ConvertTo[entity.AsynchModel](req))
return err
}
_, err = dao.Model.Update(ctx, util.ConvertTo[entity.AsynchModel](req))
return err
}
// Delete 删除模型
func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) error {
_, err := dao.Model.Delete(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
})
return err
}
// Get 获取模型详情
func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetModelRes, error) {
user, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
if g.IsEmpty(req.ID) {
req.Creator = user.UserName
}
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{
Id: req.ID,
Creator: user.UserName,
},
ModelName: req.ModelName,
IsChatModel: req.IsChatModel,
})
if err != nil {
return nil, err
}
return &dto.GetModelRes{
Model: model,
}, nil
}
// List 获取模型列表
func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (*dto.ListModelRes, error) {
// 1判断超管
req.IsOwner = gconv.PtrInt(1)
if isAdmin, _ := gateway.IsSuperAdmin(ctx); isAdmin {
req.IsOwner = gconv.PtrInt(0)
}
// 2获取当前用户
user, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
req.Creator = user.UserName
// 3查询
models, total, err := dao.Model.GetByCreatorAndPlatform(ctx, req)
if err != nil {
return nil, err
}
return &dto.ListModelRes{List: models, Total: total}, nil
}
// UpdateChatModel 设置会话模型
func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error {
// 1校验新模型存在
newModel, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
})
if err != nil || newModel == nil {
return errors.New("新会话模型不存在")
}
// 2获取当前用户的会话模型
user, err := utils.GetUserInfo(ctx)
if err != nil {
return err
}
currentModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
IsChatModel: gconv.PtrInt(1),
})
if err != nil {
return err
}
// 3事务取消旧的 + 设置新的
return gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
if !g.IsEmpty(currentModel) {
if currentModel.ModelType != public.ModelTypeInference {
return errors.New("当前模型为非推理模型,不能设置为会话模型")
}
if currentModel.Id != req.Id {
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id},
IsChatModel: gconv.PtrInt(0),
})
if err != nil {
return err
}
}
}
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
IsChatModel: gconv.PtrInt(1),
})
return err
})
}
// GetIsChatModel 获取当前用户会话模型
func (s *modelService) GetIsChatModel(ctx context.Context) (*dto.GetIsChatModelRes, error) {
user, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
IsChatModel: gconv.PtrInt(1),
})
if err != nil || model == nil {
return nil, err
}
return &dto.GetIsChatModelRes{Model: model}, nil
}
// ==================== 辅助方法 ====================
// clearUserChatModel 清除当前用户旧会话模型
func (s *modelService) clearUserChatModel(ctx context.Context) error {
user, err := utils.GetUserInfo(ctx)
if err != nil {
return err
}
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
IsChatModel: gconv.PtrInt(1),
})
if err != nil || model == nil {
return nil
}
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: model.Id},
IsChatModel: gconv.PtrInt(0),
})
return err
}
// checkChatModelUnique 校验用户是否已有会话模型
func (s *modelService) checkChatModelUnique(ctx context.Context) error {
user, err := utils.GetUserInfo(ctx)
if err != nil {
return err
}
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
IsChatModel: gconv.PtrInt(1),
})
if err != nil {
return err
}
if model != nil {
return errors.New("用户已存在会话模型")
}
return nil
}
// GetModelTypesFromConfig 从配置文件读取模型类型
func GetModelTypesFromConfig() (res *dto.TypeItem, err error) {
// 返回副本,避免外部修改
types := make(map[int]string, len(public.ModelTypeName))
for k, v := range public.ModelTypeName {
types[k] = v
}
return &dto.TypeItem{
Type: types,
}, nil
}
// GetOperatorList 获取运营商列表
func GetOperatorList() (res *dto.ListOperatorRes, err error) {
return &dto.ListOperatorRes{
List: public.OperatorList,
}, nil
}

View File

@@ -1,417 +0,0 @@
package service
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"model-gateway/model/entity"
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/frame/g"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// parseHeadMsgHeaders 支持多个 header 绑定,逗号分隔:
// 示例:
// - X-API-Key:qwen3-tts-key,operation:true,count:123
// - X-API-Key:"qwen3-tts-key",operation:"true"
//
// 说明:
// - HTTP Header 最终都是字符串,这里做的是“值的字符串化表达”。
// - 若 value 用双引号包裹,会去掉外层引号再注入,便于在配置中区分字符串/布尔/数字等表达(以及避免值中包含特殊字符时歧义)。
func parseHeadMsgHeaders(headMsg string) map[string]string {
headMsg = strings.TrimSpace(headMsg)
if headMsg == "" {
return nil
}
out := map[string]string{}
parts := strings.Split(headMsg, ",")
for _, p := range parts {
p = strings.TrimSpace(p)
if p == "" {
continue
}
// HeaderName:HeaderValue推荐 / HeaderName=HeaderValue兼容
if strings.Contains(p, ":") {
kv := strings.SplitN(p, ":", 2)
k := strings.TrimSpace(kv[0])
v := strings.TrimSpace(kv[1])
v = strings.Trim(v, "\"")
if k != "" && v != "" {
out[k] = v
}
continue
}
if strings.Contains(p, "=") {
kv := strings.SplitN(p, "=", 2)
k := strings.TrimSpace(kv[0])
v := strings.TrimSpace(kv[1])
v = strings.Trim(v, "\"")
if k != "" && v != "" {
out[k] = v
}
continue
}
}
if len(out) == 0 {
return nil
}
return out
}
func payloadToQuery(payload any) (url.Values, error) {
if payload == nil {
return url.Values{}, nil
}
// 统一转成 map[string]any
b, err := json.Marshal(payload)
if err != nil {
return nil, err
}
m := map[string]any{}
if err := json.Unmarshal(b, &m); err != nil {
return nil, err
}
q := url.Values{}
for k, v := range m {
if v == nil {
continue
}
// 复杂类型直接 json 字符串化
switch vv := v.(type) {
case string:
q.Set(k, vv)
case float64, bool, int, int64, uint64:
q.Set(k, fmt.Sprintf("%v", vv))
default:
bs, _ := json.Marshal(v)
q.Set(k, string(bs))
}
}
return q, nil
}
// InvokeModel 调用模型服务,返回二进制结果
// modelKey 用于覆盖/补充模型配置 head_msg例如每次请求携带不同的 X-API-Key
func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelKey string) ([]byte, error) {
if m == nil || m.BaseURL == "" {
return nil, fmt.Errorf("模型配置不完整")
}
// ============ 新增:请求参数映射 ============
mappedPayload, err := mapRequestPayload(m.RequestMapping, payload)
if err != nil {
return nil, fmt.Errorf("请求参数映射失败: %w", err)
}
url := strings.TrimRight(m.BaseURL, "/")
timeout := time.Duration(m.TimeoutSeconds) * time.Second
if timeout <= 0 {
timeout = 60 * time.Second
}
client := &http.Client{Timeout: timeout}
method := strings.ToUpper(strings.TrimSpace(m.HttpMethod))
if method == "" {
method = http.MethodPost
}
var (
req *http.Request
)
switch method {
case http.MethodGet:
q, err := payloadToQuery(mappedPayload) // 使用映射后的payload
if err != nil {
return nil, err
}
if len(q) > 0 {
if strings.Contains(url, "?") {
url = url + "&" + q.Encode()
} else {
url = url + "?" + q.Encode()
}
}
req, err = http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
default:
bodyBytes, err := json.Marshal(mappedPayload) // 使用映射后的payload
if err != nil {
return nil, err
}
req, err = http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyBytes))
}
if err != nil {
return nil, err
}
// 先注入模型配置 head_msg静态头部适合公共模型固定 API Key
for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) {
req.Header.Set(hk, hv)
}
// 最后注入动态 modelKey允许覆盖/补充静态 head_msg适合按请求动态传密钥。
for hk, hv := range parseHeadMsgHeaders(modelKey) {
req.Header.Set(hk, hv)
}
if method != http.MethodGet {
req.Header.Set("Content-Type", "application/json")
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
msg := string(b)
if len(msg) > 2000 {
msg = msg[:2000]
}
return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
}
// ============ 新增:响应参数映射 ============
mappedResponse, err := mapResponsePayload(m.ResponseMapping, b)
if err != nil {
// 响应映射失败不阻塞,返回原始数据
g.Log().Warningf(ctx, "响应参数映射失败: %v返回原始数据", err)
return b, nil
}
// =========================================
return mappedResponse, nil
}
// ============================================
// 映射相关函数
// ============================================
// mapRequestPayload 将标准请求映射为模型特定格式
func mapRequestPayload(mappingAny any, payload any) (any, error) {
// 1. 解析请求映射配置值是any类型支持bool、number等
mapping, err := parseRequestMapping(mappingAny)
if err != nil {
return nil, err
}
// 如果没有映射配置直接返回原始payload
if len(mapping) == 0 {
return payload, nil
}
// 2. 将payload转为map
var payloadMap map[string]any
switch v := payload.(type) {
case map[string]any:
payloadMap = v
case []map[string]any:
// 如果传进来的是纯messages数组包装成标准格式
payloadMap = map[string]any{
"messages": v,
}
default:
// 通过JSON转换
jsonBytes, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("序列化payload失败: %w", err)
}
if err := json.Unmarshal(jsonBytes, &payloadMap); err != nil {
return nil, fmt.Errorf("反序列化payload失败: %w", err)
}
}
// 3. 用数据库固定参数覆盖/补充
for key, value := range mapping {
if existingValue, exists := payloadMap[key]; !exists || isEmptyValue(existingValue) {
payloadMap[key] = value
}
}
return payloadMap, nil
}
// mapResponsePayload 将模型响应映射为标准格式
func mapResponsePayload(mappingAny any, responseBytes []byte) ([]byte, error) {
mapping, err := parseResponseMapping(mappingAny)
if err != nil {
return nil, err
}
if len(mapping) == 0 {
return responseBytes, nil
}
responseStr := string(responseBytes)
resultStr := `{}`
for standardField, modelPath := range mapping {
value := gjson.Get(responseStr, modelPath)
if !value.Exists() {
continue
}
resultStr, err = sjson.SetRaw(resultStr, standardField, value.Raw)
if err != nil {
return nil, fmt.Errorf("提取字段 %s <- %s 失败: %w", standardField, modelPath, err)
}
}
return []byte(resultStr), nil
}
func parseRequestMapping(mappingAny any) (map[string]any, error) {
if mappingAny == nil {
return nil, nil
}
result := make(map[string]any)
switch v := mappingAny.(type) {
case *gvar.Var:
if v == nil || v.IsNil() || v.IsEmpty() {
return nil, nil
}
// 尝试转成 map
if m := v.Map(); m != nil {
for k, val := range m {
result[k] = val
}
return result, nil
}
// 尝试转成 string
if s := v.String(); s != "" && s != "{}" && s != "null" {
if err := json.Unmarshal([]byte(s), &result); err != nil {
return nil, fmt.Errorf("解析请求映射字符串失败: %w", err)
}
return result, nil
}
return nil, nil
// =======================================================
case map[string]interface{}:
result = v
case string:
if v == "" || v == "{}" || v == "null" {
return nil, nil
}
if err := json.Unmarshal([]byte(v), &result); err != nil {
return nil, fmt.Errorf("解析请求映射字符串失败: %w", err)
}
case []byte:
if len(v) == 0 {
return nil, nil
}
if err := json.Unmarshal(v, &result); err != nil {
return nil, fmt.Errorf("解析请求映射字节失败: %w", err)
}
default:
jsonBytes, err := json.Marshal(mappingAny)
if err != nil {
return nil, fmt.Errorf("序列化映射配置失败: %w", err)
}
if err := json.Unmarshal(jsonBytes, &result); err != nil {
return nil, fmt.Errorf("解析映射配置失败: %w", err)
}
}
return result, nil
}
// parseResponseMapping 解析响应映射配置
// 返回值类型为 map[string]string值都是JSON路径字符串
func parseResponseMapping(mappingAny any) (map[string]string, error) {
if mappingAny == nil {
return nil, nil
}
mapping := make(map[string]string)
switch v := mappingAny.(type) {
case *gvar.Var:
if v == nil || v.IsNil() || v.IsEmpty() {
return nil, nil
}
if m := v.Map(); m != nil {
for k, val := range m {
if strVal, ok := val.(string); ok {
mapping[k] = strVal
}
}
return mapping, nil
}
if s := v.String(); s != "" && s != "{}" && s != "null" {
if err := json.Unmarshal([]byte(s), &mapping); err != nil {
return nil, fmt.Errorf("解析响应映射字符串失败: %w", err)
}
return mapping, nil
}
return nil, nil
case string:
if v == "" || v == "{}" || v == "null" {
return nil, nil
}
if err := json.Unmarshal([]byte(v), &mapping); err != nil {
return nil, fmt.Errorf("解析响应映射字符串失败: %w", err)
}
case map[string]interface{}:
// 数据库JSONB直接返回的map
for k, val := range v {
if strVal, ok := val.(string); ok {
mapping[k] = strVal
}
}
case []byte:
if len(v) == 0 {
return nil, nil
}
if err := json.Unmarshal(v, &mapping); err != nil {
return nil, fmt.Errorf("解析响应映射字节失败: %w", err)
}
default:
jsonBytes, err := json.Marshal(mappingAny)
if err != nil {
return nil, fmt.Errorf("序列化响应映射配置失败: %w", err)
}
if err := json.Unmarshal(jsonBytes, &mapping); err != nil {
return nil, fmt.Errorf("解析响应映射配置失败: %w", err)
}
}
return mapping, nil
}
// isEmptyValue 判断值是否为空
func isEmptyValue(v any) bool {
if v == nil {
return true
}
switch val := v.(type) {
case string:
return val == ""
case []any:
return len(val) == 0
case map[string]any:
return len(val) == 0
default:
return false
}
}

View File

@@ -1,255 +0,0 @@
package service
import (
"context"
"errors"
"model-gateway/dao"
"model-gateway/model/dto"
"model-gateway/model/entity"
"gitea.com/red-future/common/beans"
"gitea.com/red-future/common/db/gfdb"
"gitea.com/red-future/common/http"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
var Model = &modelService{}
type modelService struct{}
// IsSuperAdmin 调用admin-go服务检查是否是超级管理员
func (s *modelService) IsSuperAdmin(ctx context.Context) (res bool, err error) {
headers := forwardHeaders(ctx)
var r = make(map[string]bool)
if err = http.Get(ctx, "admin-go/api/v1/system/user/checkIsSuperAdmin", headers, &r); err != nil {
return false, err
}
return r["isSuperAdmin"], err
}
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
// 获取当前会话模型
if !g.IsEmpty(req.IsChatModel) && *req.IsChatModel == 1 {
var model *entity.AsynchModel
model, err = dao.Model.GetByIsChatModel(ctx)
if err != nil {
return nil, err
}
// 如果有会话模型,那就改变为 0
if model != nil {
_, err = dao.Model.Update(ctx, &dto.UpdateModelReq{
ID: model.Id,
IsChatModel: gconv.PtrInt(0),
})
if err != nil {
return nil, err
}
}
}
req.IsOwner = gconv.PtrInt(1)
admin, err := s.IsSuperAdmin(ctx)
if err != nil {
return
}
if admin {
req.IsOwner = gconv.PtrInt(0)
}
id, err := dao.Model.Insert(ctx, req)
if err != nil {
return nil, err
}
return &dto.CreateModelRes{ID: id}, nil
}
func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) error {
//根据当前 isChatModel 来判断是否更新模型
if req.IsChatModel == gconv.PtrInt(1) {
//判断当前用户是否有会话模型
model, err := dao.Model.GetByIsChatModel(ctx)
if err != nil {
return err
}
if model != nil {
return errors.New("用户已存在会话模型,不能创建")
}
}
req.IsOwner = gconv.PtrInt(1)
admin, err := s.IsSuperAdmin(ctx)
if err != nil {
return err
}
if admin {
req.IsOwner = gconv.PtrInt(0)
_, err = dao.Model.Update(ctx, req)
if err != nil {
return err
}
return nil
}
var user *beans.User
user, err = utils.GetUserInfo(ctx)
if err != nil {
return err
}
// 判断当前传过来的模型id的模型是否是超级管理员的。如果是超管的进行创建否则更新
var count int
count, err = dao.Model.Count(ctx, &dto.GetModelReq{
ID: req.ID,
Creator: user.UserName,
})
if err != nil {
return err
}
if count == 0 {
insertDto := new(dto.CreateModelReq)
err = gconv.Struct(req, insertDto)
if err != nil {
return err
}
_, err = dao.Model.Insert(ctx, insertDto)
return err
}
_, err = dao.Model.Update(ctx, req)
return err
}
func (s *modelService) Delete(ctx context.Context, id string) error {
_, err := dao.Model.DeleteByID(ctx, id)
return err
}
func (s *modelService) Get(ctx context.Context, id int64) (*entity.AsynchModel, error) {
model, err := dao.Model.Get(ctx, id)
if err != nil {
return nil, err
}
model.Form = ParseJSONField(model.Form)
model.RequestMapping = ParseJSONField(model.RequestMapping)
model.ResponseMapping = ParseJSONField(model.ResponseMapping)
model.ResponseBody = ParseJSONField(model.ResponseBody)
return model, nil
}
func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (list []*entity.AsynchModel, total int, err error) {
var models []*entity.AsynchModel
req.IsOwner = gconv.PtrInt(1)
admin, err := s.IsSuperAdmin(ctx)
if err != nil {
return
}
if admin {
req.IsOwner = gconv.PtrInt(0)
}
var user *beans.User
user, err = utils.GetUserInfo(ctx)
if err != nil {
return nil, 0, err
}
req.Creator = user.UserName
models, total, err = dao.Model.GetByCreatorAndPlatform(ctx, req)
if err != nil {
return
}
// 处理列表中每条记录的 JSONB 字段
for _, m := range models {
m.Form = ParseJSONField(m.Form)
m.RequestMapping = ParseJSONField(m.RequestMapping)
m.ResponseMapping = ParseJSONField(m.ResponseMapping)
m.ResponseBody = ParseJSONField(m.ResponseBody)
}
return models, total, nil
}
// GetModelTypesFromConfig 从配置文件读取模型类型
func GetModelTypesFromConfig(ctx context.Context) map[int]string {
typeMap := make(map[int]string)
// 读取配置
configMap := g.Cfg().MustGet(ctx, "modelType.types").Map()
for k, v := range configMap {
typeID := gconv.Int(k)
typeName := gconv.String(v)
if typeID > 0 && typeName != "" {
typeMap[typeID] = typeName
}
}
// 如果配置为空,使用默认值
if len(typeMap) == 0 {
typeMap = map[int]string{
1: "推理模型",
2: "图片模型",
3: "音频模型",
4: "向量化模型",
5: "全模态模型",
}
}
return typeMap
}
func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error {
// 校验新会话模型是否存在
newModel, err := dao.Model.Get(ctx, req.Id)
if err != nil {
return err
}
if newModel == nil {
return errors.New("新会话模型不存在")
}
// 获取当前用户会话模型
currentModel, err := dao.Model.GetByIsChatModel(ctx)
if err != nil {
return err
}
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
if !g.IsEmpty(currentModel) {
if currentModel.ModelType != 1 {
return errors.New("当前模型为非推理模型,不能设置为会话模型")
}
// 如果点击的就是当前会话模型已经是1取消它设为0
if currentModel.Id != req.Id {
_, err = dao.Model.Update(ctx, &dto.UpdateModelReq{
ID: currentModel.Id,
IsChatModel: gconv.PtrInt(0),
})
if err != nil {
return err
}
}
}
// 设置当前为会话模型设为1
_, err = dao.Model.Update(ctx, &dto.UpdateModelReq{
ID: req.Id,
IsChatModel: gconv.PtrInt(1),
})
return err
})
return err
}
func (s *modelService) GetIsChatModel(ctx context.Context) (*entity.AsynchModel, error) {
model, err := dao.Model.GetByIsChatModel(ctx)
if err != nil {
return nil, err
}
if model == nil {
return nil, nil
}
model.Form = ParseJSONField(model.Form)
model.RequestMapping = ParseJSONField(model.RequestMapping)
model.ResponseMapping = ParseJSONField(model.ResponseMapping)
model.ResponseBody = ParseJSONField(model.ResponseBody)
return model, nil
}

View File

@@ -1,25 +0,0 @@
package service
import "github.com/gogf/gf/v2/util/gconv"
// parseStoredPayload 解析入库的 request_payload拆出模型调用 payload 与透传 headers
// 入库格式:{"payload": <any>, "headers": {"Authorization": "...", "X-User-Info":"..."}}
func parseStoredPayload(v any) (payload any, headers map[string]string) {
if v == nil {
return nil, nil
}
m := gconv.Map(v)
if len(m) == 0 {
return v, nil
}
if h, ok := m["headers"]; ok {
headers = gconv.MapStrStr(h)
}
if p, ok := m["payload"]; ok {
payload = p
} else {
payload = v
}
return
}

View File

@@ -1,14 +1,16 @@
package service
package queue
import (
"context"
"errors"
"fmt"
"math"
"model-gateway/model/dto"
"model-gateway/consts/public"
"model-gateway/model/entity"
"gitea.com/red-future/common/db/gfdb"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/frame/g"
)
@@ -26,7 +28,6 @@ type AutoTuneResult struct {
OldQueueLimit int `json:"oldQueueLimit"` // 调参前运行时值Redis若无则等于 cap
NewQueueLimit int `json:"newQueueLimit"` // 本次计算出的运行时值(将写入 Redis受 ±50% 约束且不超过 cap
ExpectedSeconds int `json:"expectedSeconds"` // 模型预计执行时间asynch_models.expected_seconds用于 queue_limit 计算绑定)
}
// AutoTune 由上层定时任务通过接口触发:
@@ -34,9 +35,12 @@ type AutoTuneResult struct {
// - 基于吞吐与 P90 执行耗时估算 max_concurrency 的运行时值(不超过 cap
// - queue_limit 与 expected_seconds 绑定(允许排队时间 = expected_seconds * 2生成运行时值不超过 cap
// - 单次调整幅度限制 ±50%,写入 Redis带 TTL
func AutoTune(ctx context.Context, windowSeconds int) ([]AutoTuneResult, error) {
if windowSeconds <= 0 {
windowSeconds = 3600
func AutoTune(ctx context.Context, req *dto.AutoTuneReq) (res *dto.AutoTuneRes, err error) {
if req == nil {
return nil, errors.New("request cannot be nil")
}
if req.WindowSeconds <= 0 {
req.WindowSeconds = 3600 // 默认1小时
}
// 1) 读取模型配置cap按 model_name 聚合去重(如果表里有多租户重复数据,取较大上限)
var modelRows []*entity.AsynchModel
@@ -60,15 +64,15 @@ func AutoTune(ctx context.Context, windowSeconds int) ([]AutoTuneResult, error)
if m.MaxConcurrency > cur.MaxConcurrency {
cur.MaxConcurrency = m.MaxConcurrency
}
if m.QueueLimit > cur.QueueLimit {
cur.QueueLimit = m.QueueLimit
if m.MaxConcurrency*2 > cur.MaxConcurrency*2 {
cur.MaxConcurrency = m.MaxConcurrency
}
if m.ExpectedSeconds > cur.ExpectedSeconds {
cur.ExpectedSeconds = m.ExpectedSeconds
if m.TimeoutSeconds > cur.TimeoutSeconds {
cur.TimeoutSeconds = m.TimeoutSeconds
}
}
if len(modelMap) == 0 {
return []AutoTuneResult{}, nil
return nil, errors.New("no models found")
}
// 2) 统计指定窗口:按 model_name 计算 cnt 和 P90 执行耗时
@@ -89,7 +93,7 @@ SELECT model_name,
AND finished_at IS NOT NULL
AND finished_at >= (NOW() - (? || ' seconds')::interval)
GROUP BY model_name`, public.TableNameTask)
r, err := gfdb.DB(ctx).GetAll(ctx, sql, windowSeconds)
r, err := gfdb.DB(ctx).GetAll(ctx, sql, req.WindowSeconds)
if err != nil {
return nil, err
}
@@ -108,7 +112,7 @@ SELECT model_name,
for modelName, m := range modelMap {
s := statMap[modelName]
capMax := m.MaxConcurrency
capQueue := m.QueueLimit
capQueue := m.MaxConcurrency * 2
oldMax := GetRuntimeMaxConcurrency(ctx, modelName, capMax)
oldQueue := GetRuntimeQueueLimit(ctx, modelName, capQueue)
@@ -124,7 +128,6 @@ SELECT model_name,
CapQueueLimit: capQueue,
OldQueueLimit: oldQueue,
NewQueueLimit: oldQueue,
ExpectedSeconds: m.ExpectedSeconds,
})
continue
}
@@ -150,7 +153,7 @@ SELECT model_name,
setRuntimeInt(ctx, runtimeMaxConcurrencyKey(modelName), newMax)
// queue_limitW_target = expected_seconds * queueFactor
exp := m.ExpectedSeconds
exp := m.TimeoutSeconds
if exp <= 0 {
exp = 60
}
@@ -185,10 +188,11 @@ SELECT model_name,
CapQueueLimit: capQueue,
OldQueueLimit: oldQueue,
NewQueueLimit: newQueue,
ExpectedSeconds: m.ExpectedSeconds,
})
}
g.Log().Infof(ctx, "[auto_tune] done models=%d windowSeconds=%d", len(out), windowSeconds)
return out, nil
g.Log().Infof(ctx, "[auto_tune] done models=%d windowSeconds=%d", len(out), req.WindowSeconds)
return &dto.AutoTuneRes{
List: out,
}, nil
}

View File

@@ -1,4 +1,4 @@
package service
package queue
import (
"context"

View File

@@ -1,4 +1,4 @@
package service
package queue
import (
"context"
@@ -11,9 +11,9 @@ import (
// 上层每小时调用 /model/autoTune 写入运行时值Worker/CreateTask 读取运行时值生效。
const (
runtimeMaxCKeyPrefix = "asynch:runtime:max_concurrency:" // + model_name
runtimeQueueKeyPrefix = "asynch:runtime:queue_limit:" // + model_name
runtimeTTLSeconds = 2 * 3600 // 2小时避免一次调参失败导致立即回退
runtimeMaxCKeyPrefix = "asynch:runtime:max_concurrency:" // + model_name
runtimeQueueKeyPrefix = "asynch:runtime:queue_limit:" // + model_name
runtimeTTLSeconds = 2 * 3600 // 2小时避免一次调参失败导致立即回退
)
func runtimeMaxConcurrencyKey(modelName string) string {
@@ -80,4 +80,3 @@ func clampInt(v, minV, maxV int) int {
}
return v
}

View File

@@ -1,4 +1,4 @@
package service
package queue
import (
"context"
@@ -34,7 +34,8 @@ end
return 1
`
func acquireSemaphore(ctx context.Context, key string, max int, ttlSeconds int64) (bool, error) {
// AcquireSemaphore 获取并发令牌
func AcquireSemaphore(ctx context.Context, key string, max int, ttlSeconds int64) (bool, error) {
if max <= 0 {
// 不限制
return true, nil
@@ -49,8 +50,8 @@ func acquireSemaphore(ctx context.Context, key string, max int, ttlSeconds int64
return gconv.Int(r) == 1, nil
}
func releaseSemaphore(ctx context.Context, key string) error {
// ReleaseSemaphore 释放并发令牌
func ReleaseSemaphore(ctx context.Context, key string) error {
_, err := g.Redis().Do(ctx, "EVAL", releaseLua, 1, key)
return err
}

View File

@@ -1,4 +1,4 @@
package service
package stat
import (
"context"

View File

@@ -1,18 +0,0 @@
package service
import (
"context"
"errors"
"model-gateway/model/entity"
)
// StorageService 结果存储OSS/MinIO抽象
type StorageService interface {
UploadByTask(ctx context.Context, t *entity.AsynchTask, data []byte, fileExt string, contentType string) (ossURL string, err error)
}
// Storage 默认存储实现(优先对接你们的 oss 文件服务;必要时也可以切到 MinIO
var Storage StorageService = &ossStorage{}
var ErrStorageNotConfigured = errors.New("存储未配置")

View File

@@ -1,81 +0,0 @@
package service
import (
"bytes"
"context"
"fmt"
"mime/multipart"
"time"
"model-gateway/model/entity"
commonHttp "gitea.com/red-future/common/http"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
"github.com/gogf/gf/v2/util/guid"
)
// 对接你们的 oss 文件服务POST oss/file/uploadFile (multipart/form-data)
type ossStorage struct{}
type uploadFileResponse struct {
FileURL string `json:"fileURL"` // 文件 URL
FileSize int `json:"fileSize"` // 文件大小(字节)
FileName string `json:"fileName"` // 文件名
FileFormat string `json:"fileFormat"` // 文件格式
FileAddressPrefix string `json:"fileAddressPrefix"` // 文件地址前缀
}
func (s *ossStorage) UploadByTask(ctx context.Context, _ *entity.AsynchTask, data []byte, fileExt string, _ string) (ossURL string, err error) {
// multipart
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
ext := fileExt
if ext == "" {
ext = ".bin"
}
if ext[0] != '.' {
ext = "." + ext
}
filename := fmt.Sprintf("asynch_%d_%s%s", time.Now().Unix(), guid.S(), ext)
part, err := writer.CreateFormFile("file", filename)
if err != nil {
return "", err
}
if _, err := part.Write(data); err != nil {
return "", err
}
contentType := writer.FormDataContentType()
if err := writer.Close(); err != nil {
return "", err
}
headers := forwardHeaders(ctx)
headers["Content-Type"] = contentType
fullURL := "oss/file/uploadFile"
g.Log().Infof(ctx, "[OSS] upload start url=%s filename=%s size=%d", fullURL, filename, len(data))
var resp uploadFileResponse
if err := commonHttp.Post(ctx, fullURL, headers, &resp, body.Bytes()); err != nil {
return "", err
}
g.Log().Infof(ctx, "[OSS] upload success url=%s size=%d format=%s", resp.FileURL, resp.FileSize, resp.FileFormat)
return resp.FileURL, nil
}
// setTaskHeadersToCtx 把任务入库时保存的 header 信息注入 ctx给 worker 调 OSS 用
func setTaskHeadersToCtx(ctx context.Context, headers map[string]string) context.Context {
if headers == nil {
return ctx
}
if v := gconv.String(headers["Authorization"]); v != "" {
ctx = context.WithValue(ctx, "token", v)
}
if v := gconv.String(headers["X-User-Info"]); v != "" {
ctx = context.WithValue(ctx, "xUserInfo", v)
}
return ctx
}

View File

@@ -0,0 +1,305 @@
package task
import (
"context"
"errors"
"fmt"
"model-gateway/common/util"
"model-gateway/consts/public"
"model-gateway/service/queue"
"time"
"model-gateway/dao"
"model-gateway/model/dto"
"model-gateway/model/entity"
"gitea.redpowerfuture.com/red-future/common/beans"
"gitea.redpowerfuture.com/red-future/common/utils"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/util/gconv"
"github.com/google/uuid"
)
var Task = &taskService{}
type taskService struct{}
// Create 创建任务
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
startAt := time.Now()
taskID := uuid.NewString()
// 1) 检查模型配置,并且获取模型
userInfo, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{
TenantId: userInfo.TenantId,
Creator: userInfo.UserName,
},
ModelName: req.ModelName,
})
if err != nil {
return nil, err
}
if model == nil || (model.Enabled != nil && *model.Enabled != 1) {
return nil, errors.New("模型不存在或未启用")
}
// 2) 排队上限严格控制Redis 原子闸门)
limit := queue.GetRuntimeQueueLimit(ctx, req.ModelName, model.MaxConcurrency*2)
if limit > 0 {
ok, err := queue.AcquireQueueSlot(ctx, req.ModelName, taskID, limit, model.TimeoutSeconds)
if err != nil {
return nil, err
}
if !ok {
return nil, errors.New("任务排队已满,请稍后再试")
}
}
// 3) 插入任务记录
if model.CallMode != nil && *model.CallMode == public.CallModeAsync {
// 异步调用:注入回调地址后提交,拿到 task_id 轮询
req.RequestPayload = util.InjectCallbackURL(ctx, req.RequestPayload, model.CallbackUrl)
}
storedPayload := map[string]any{
"headers": util.ParseHeadMsgHeaders(model.HeadMsg),
"body": req.RequestPayload,
}
_, err = dao.Task.Insert(ctx, &entity.AsynchTask{
ModelName: req.ModelName,
TaskID: taskID,
State: 0,
BizName: req.BizName,
CallbackURL: req.CallbackUrl,
ModelKey: model.ApiKey,
InputRef: req.InputRef,
RequestPayload: storedPayload,
EpicycleId: req.EpicycleId,
})
if err != nil { // 入库失败:回滚闸门占位
queue.ReleaseQueueSlot(ctx, req.ModelName, taskID)
return nil, err
}
// 4) 写操作日志(不影响主流程,失败忽略)
ip := ""
ua := ""
apiPath := "/task/createTask"
httpMethod := "POST"
if r := g.RequestFromCtx(ctx); r != nil {
ip = utils.GetLocalIP()
ua = r.UserAgent()
apiPath = r.URL.Path
httpMethod = r.Method
}
_, _ = dao.OpLog.Insert(ctx, &entity.LogsModelOp{
IP: ip,
UserAgent: ua,
APIPath: apiPath,
HttpMethod: httpMethod,
BizName: req.BizName,
ModelName: req.ModelName,
TaskID: taskID,
OpType: "createTask",
Success: 1,
ErrorMsg: "",
CostMs: time.Since(startAt).Milliseconds(),
RequestPayload: storedPayload,
ResponsePayload: gdb.Map{
"taskId": taskID,
},
})
// 5) 获取任务信息
task, err := dao.Task.ClaimPendingByTaskIDGlobal(ctx, taskID)
if err != nil {
return nil, err
}
if task == nil {
return nil, err
}
// 5) 创建成功后立即异步尝试执行当前任务
go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req)
return &dto.CreateTaskRes{TaskID: taskID}, nil
}
func (s *taskService) ModelTaskCallback(ctx context.Context, req *dto.ModelTaskCallbackReq) (*dto.ModelTaskCallbackRes, error) {
g.Log().Infof(ctx, "[模型回调] 收到通知 taskID=%s status=%s", req.TaskID, req.Status)
// 1. 查本地任务
task, err := dao.Task.Get(ctx, &entity.AsynchTask{
TaskID: req.TaskID,
})
if err != nil || task == nil {
return nil, fmt.Errorf("任务不存在: %s", req.TaskID)
}
// 2. 成功:取 video_url 和 usage
if req.Status == "succeeded" {
result := map[string]any{
"video_url": req.Content["video_url"],
"usage": req.Usage,
}
NotifyAsyncResult(req.TaskID, result, nil)
return &dto.ModelTaskCallbackRes{Success: true}, nil
}
// 3. 失败/过期
if req.Status == "failed" || req.Status == "expired" {
NotifyAsyncResult(req.TaskID, nil, fmt.Errorf(req.Status))
return &dto.ModelTaskCallbackRes{Success: true}, nil
}
return &dto.ModelTaskCallbackRes{Success: true}, nil
}
// QueryPendingTasks 批量轮询进行中的异步任务
func (s *taskService) QueryPendingTasks(ctx context.Context, req *dto.QueryPendingTasksReq) (*dto.QueryPendingTasksRes, error) {
limit := req.Limit
if limit <= 0 {
limit = g.Cfg().MustGet(ctx, "asynch.queryPending.limit", 10).Int()
}
// 1. 查 state=1执行中的异步任务
tasks, err := dao.Task.GetPendingAsyncTasks(ctx, limit)
if err != nil {
return nil, err
}
// 2. 逐个查询
var results []dto.QueryTaskItem
for _, t := range tasks {
// 拿到模型配置
model, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
if err != nil || model == nil || model.QueryConfig == nil {
continue
}
result, err := util.PullTaskResult(ctx, nil, model.QueryConfig, model.HeadMsg)
if err != nil {
g.Log().Warningf(ctx, "[轮询] 查询失败 taskID=%s err=%v", t.TaskID, err)
continue
}
status := gconv.String(result["status"])
item := dto.QueryTaskItem{
TaskID: t.TaskID,
Status: status,
Content: result["content"].(map[string]any),
Usage: result["usage"].(map[string]any),
}
results = append(results, item)
// 如果任务完成,通知等待通道
if status == "succeeded" || status == "failed" || status == "expired" {
NotifyAsyncResult(t.TaskID, result["content"].(map[string]any), nil)
}
}
return &dto.QueryPendingTasksRes{
Total: len(results),
Results: results,
}, nil
}
// GetResult 获取任务结果
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
TaskID: taskID,
})
if err != nil {
return nil, err
}
if t == nil {
return nil, errors.New("任务不存在")
}
return &dto.GetTaskResultRes{
OssFile: t.OssFile,
State: t.State,
}, nil
}
// GetBatch 批量查询任务;将成功(state=2)的任务更新为已下载(state=4),并写入过期时间
func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
if req == nil || len(req.TaskIDs) == 0 {
return &dto.GetTaskBatchRes{List: []dto.GetTaskBatchItem{}}, nil
}
// 1) 先查当前租户下的任务列表
list, err := dao.Task.ListByTaskIDs(ctx, req.TaskIDs)
if err != nil {
return nil, err
}
// 2) 对成功(state=2)的任务:标记为已下载(state=4)并写入 expire_at
now := time.Now()
for _, t := range list {
if t == nil {
continue
}
if t.State != 2 {
continue
}
// 按模型配置决定保留时间
m, err := dao.Model.Get(ctx, &entity.AsynchModel{
ModelName: t.ModelName,
})
if err != nil {
return nil, err
}
retainSeconds := 86400
if m != nil && m.AutoCleanSeconds > 0 {
retainSeconds = m.AutoCleanSeconds
}
expireAt := gtime.New(now.Add(time.Duration(retainSeconds) * time.Second))
_ = dao.Task.MarkDownloadedByID(ctx, t.Id, expireAt)
// 为了本次返回一致性,内存里也更新
t.State = 4
t.ExpireAt = expireAt
}
// 3) 组装返回
items := make([]dto.GetTaskBatchItem, 0, len(list))
for _, t := range list {
if t == nil {
continue
}
items = append(items, dto.GetTaskBatchItem{
TaskID: t.TaskID,
State: t.State,
OssFile: t.OssFile,
})
}
return &dto.GetTaskBatchRes{List: items}, nil
}
// List 获取任务列表
func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
pageNum, pageSize := 1, 10
if req != nil {
if req.PageNum > 0 {
pageNum = req.PageNum
}
if req.PageSize > 0 {
pageSize = req.PageSize
}
}
modelName := ""
taskID := ""
var state *int
if req != nil {
modelName = req.ModelName
taskID = req.TaskID
state = req.State
}
list, total, err := dao.Task.List(ctx, pageNum, pageSize, modelName, taskID, state)
if err != nil {
return nil, err
}
return &dto.ListTaskRes{List: list, Total: total}, nil
}

494
service/task/worker.go Normal file
View File

@@ -0,0 +1,494 @@
package task
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"strings"
"sync"
"time"
"unicode/utf8"
"model-gateway/common/util"
"model-gateway/consts/public"
"model-gateway/dao"
"model-gateway/model/dto"
"model-gateway/model/entity"
"model-gateway/service/gateway"
"model-gateway/service/queue"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
var AsyncWorker = &asyncWorker{}
type asyncWorker struct {
}
// handleOne 执行一次完整的任务
func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, req *dto.CreateTaskReq) {
body := util.GetModelBody(task.RequestPayload) // 核心请求参数
maxRetry := model.RetryTimes // 重试次数
startTime := time.Now()
g.Log().Infof(ctx, "[执行任务][开始] taskId=%s model=%s", task.TaskID, task.ModelName)
// 1) 分布式并发控制
semKey := fmt.Sprintf("asynch:sem:%s", task.ModelName)
maxC := queue.GetRuntimeMaxConcurrency(ctx, task.ModelName, model.MaxConcurrency)
acquired, err := queue.AcquireSemaphore(ctx, semKey, maxC, 3600)
if err != nil {
task.DurationSeconds = int64(time.Since(startTime).Seconds())
w.failTask(ctx, task, startTime, err.Error())
return
}
if !acquired {
g.Log().Infof(ctx, "[执行任务][排队] 并发已满,放回队列 taskId=%s", task.TaskID)
_ = w.rollbackToPending(ctx, task.Id)
return
}
defer func() { _ = queue.ReleaseSemaphore(ctx, semKey) }()
// 2) 调用模型
switch {
case model.CallMode != nil && *model.CallMode == public.CallModeStream:
rawBytes, err := w.callModelStream(ctx, task, model, body)
if err != nil {
w.failTask(ctx, task, startTime, err.Error())
return
}
body, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
if err != nil {
w.failTask(ctx, task, startTime, err.Error())
return
}
case model.CallMode != nil && *model.CallMode == public.CallModeAsync:
body, err = w.callModel(ctx, task, model, body)
if err != nil {
w.failTask(ctx, task, startTime, err.Error())
return
}
body, err = util.PullTaskResult(ctx, body, model.QueryConfig, model.HeadMsg)
if err != nil {
w.failTask(ctx, task, startTime, err.Error())
return
}
default:
body, err = w.callModel(ctx, task, model, body)
if err != nil {
w.failTask(ctx, task, startTime, err.Error())
return
}
}
// 3) 保存临时文件
tmpPath, err := util.SaveTempFileByType(task.TaskID, body, task.TmpFile)
if err == nil && tmpPath != "" {
task.TmpFile = tmpPath
task.Phase = 1
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath)
}
// 4) 解析校验 + 响应映射(可重试,失败重新调模型)
body, err = w.parseAndRetry(ctx, body, task, model, req, maxRetry, startTime)
if err != nil {
task.TextResult = body
w.failTask(ctx, task, startTime, err.Error())
return
}
// 5) 上传 OSS可重试
var oss *gateway.UploadFileResponse
for attempt := 0; attempt <= maxRetry; attempt++ {
if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] OSS上传 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
}
oss, err = w.uploadOSS(ctx, task)
if err == nil {
break
}
g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v",
task.TaskID, attempt, maxRetry, err)
if attempt == maxRetry {
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, task.Id, err.Error())
w.failTask(ctx, task, startTime, fmt.Sprintf("OSS上传重试耗尽: %v", err))
return
}
}
// 6) 成功回调
task.State = 2
task.DurationSeconds = int64(time.Since(startTime).Seconds())
task.OssFile = oss.FileAddressPrefix + oss.FileURL
task.FileType = oss.FileFormat
task.TextResult = body
task.FileSize = int64(oss.FileSize)
if err = dao.Task.UpdateSuccessGlobal(ctx, task); err != nil {
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err)
return
}
queue.ReleaseQueueSlot(ctx, task.ModelName, task.TaskID)
go gateway.TriggerCallback(context.WithoutCancel(ctx), task)
if req.EpicycleId != 0 {
go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), task, req.EpicycleId)
}
g.Log().Infof(ctx, "[执行任务][成功] taskId=%s duration=%ds fileType=%s textLen=%d callbackUrl=%s",
task.TaskID, task.DurationSeconds, oss.FileFormat, len(body), task.CallbackURL)
// 7) 删除临时文件
_ = os.Remove(task.TmpFile)
}
// callModelStream 调用模型,返回原始字节(不做响应映射,用于流式输出)
func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) ([]byte, error) {
var data []byte
var err error
if task.Phase == 1 && strings.TrimSpace(task.TmpFile) != "" {
data, err = os.ReadFile(task.TmpFile)
if err != nil || len(data) == 0 {
data = nil
}
}
if data == nil {
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName)
data, err = InvokeModel(ctx, model, body, task.ModelKey)
if err != nil {
return nil, err
}
tmpPath, tmpErr := util.SaveTmpResult(task.TaskID, data, "")
if tmpErr == nil && tmpPath != "" {
task.TmpFile = tmpPath
task.Phase = 1
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath)
}
}
return data, nil
}
// asyncResult 异步任务结果
type asyncResult struct {
result map[string]any
err error
}
// asyncTaskChan 全局异步任务等待通道
var asyncTaskChan = sync.Map{} // taskID → chan asyncResult
func (w *asyncWorker) callModelAsync(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) (map[string]any, error) {
// 1. 提交异步任务
body, err := w.callModel(ctx, task, model, body)
if err != nil {
return nil, err
}
// 2. 拿到 task_id
taskID := gjson.New(body).Get(model.ResponseBody).String()
// 3. 创建等待通道
ch := make(chan asyncResult, 1)
asyncTaskChan.Store(taskID, ch)
defer func() {
asyncTaskChan.Delete(taskID)
close(ch)
}()
// 4. 阻塞等待回调或超时
timeout := time.Duration(model.TimeoutSeconds) * time.Second
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
g.Log().Infof(ctx, "[异步任务] 开始等待结果 taskID=%s timeout=%v", taskID, timeout)
select {
case res, ok := <-ch:
if !ok {
return nil, fmt.Errorf("异步任务通道已关闭: taskID=%s", taskID)
}
g.Log().Infof(ctx, "[异步任务] 获取结果成功 taskID=%s", taskID)
return res.result, res.err
case <-ctx.Done():
return nil, fmt.Errorf("异步任务超时: taskID=%s", taskID)
}
}
// NotifyAsyncResult 回调接口调用此方法通知结果
func NotifyAsyncResult(taskID string, result map[string]any, err error) {
if ch, ok := asyncTaskChan.Load(taskID); ok {
ch.(chan asyncResult) <- asyncResult{result: result, err: err}
}
}
// callModel 调用模型 + 检测文件类型 + 保存临时文件
// 返回: 解析后的响应体, error
func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) (map[string]any, error) {
var data []byte
var err error
// 1) 如果已有临时文件且 phase=1直接读取
if task.Phase == 1 && strings.TrimSpace(task.TmpFile) != "" {
data, err = os.ReadFile(task.TmpFile)
if err != nil || len(data) == 0 {
g.Log().Warningf(ctx, "[callModel] 读取临时文件失败,重新调用模型 taskId=%s err=%v", task.TaskID, err)
data = nil
}
}
// 2) 没有可用数据,调用模型
if data == nil {
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName)
data, err = InvokeModel(ctx, model, body, task.ModelKey)
if err != nil {
return nil, err
}
// 3) 检测文件类型,保存临时文件
_, ext := util.DetectFileType(data)
tmpPath, tmpErr := util.SaveTmpResult(task.TaskID, data, ext)
if tmpErr == nil && tmpPath != "" {
task.TmpFile = tmpPath
task.Phase = 1
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath)
}
}
// 4) 检测文件类型,提取文本结果
contentType, _ := util.DetectFileType(data)
var textResult string
if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") {
textResult = string(data)
}
// 5) 非文本内容,返回错误
if textResult == "" {
return nil, fmt.Errorf("模型返回非文本内容contentType=%s", contentType)
}
// 6) 解析并返回
return gjson.New(textResult).Map(), nil
}
// parseAndRetry 解析模型返回结果,并重试
func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.AsynchTask, model *entity.AsynchModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) {
for attempt := 0; attempt <= maxRetry; attempt++ {
if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
}
// 1) 响应映射
mapped, err := util.MapResponsePayload(model.ResponseMapping, body)
if err != nil {
g.Log().Warningf(ctx, "[执行任务][映射失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
if attempt == maxRetry {
return nil, fmt.Errorf("响应映射重试耗尽: %w", err)
}
continue
}
// 2) 先存 token 到数据库,防止后续失败丢失
if tokens, ok := mapped[model.ResponseTokenField]; ok {
task.ExpendTokens = gconv.Int64(tokens)
_ = dao.Task.UpdateColumns(ctx, task.Id, entity.AsynchTask{
ExpendTokens: gconv.Int64(body[model.ResponseTokenField]),
})
}
// 3) 解析 + 校验
var parsed map[string]any
switch req.BuildType {
case public.BuildTypePrompt, public.BuildTypeNode:
parsed, err = util.ParseAndValidate(mapped, model)
if err == nil {
return parsed, nil
}
case public.BuildTypeStruct:
parsed = util.ParseStructResult(mapped, model.ResponseBody)
return parsed, nil
default:
return mapped, nil
}
g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
if attempt == maxRetry {
return nil, fmt.Errorf("JSON解析重试耗尽: %w", err)
}
// 4) 重新调模型(直接调,不走缓存)
_ = dao.Task.IncRetryCountGlobal(ctx, task.Id)
reqBody := util.GetModelBody(task.RequestPayload)
rawData, callErr := InvokeModel(ctx, model, reqBody, task.ModelKey)
if callErr != nil {
g.Log().Warningf(ctx, "[执行任务][重调模型失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, callErr)
continue
}
// 5) 解析原始响应,覆盖 body 进入下一轮
var rawResp map[string]any
if err := json.Unmarshal(rawData, &rawResp); err != nil {
g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err)
continue
}
body = rawResp
}
return body, nil
}
// InvokeModel 调用模型服务,返回二进制结果
// modelKey 用于覆盖/补充模型配置 head_msg例如每次请求携带不同的 X-API-Key
func InvokeModel(ctx context.Context, model *entity.AsynchModel, body map[string]any, modelKey string) ([]byte, error) {
// 1请求参数映射将标准 payload 按模型配置的 requestMapping 转为模型需要的格式
//—— 请求映射实际处理为提示词构建请求,因为有附加字段及其他字段的拼接。这里不方便做请求映射
//mappedPayload := util.ReverseMap(model.RequestMapping, payload)
// 2构建请求 URL 和超时
baseURL := strings.TrimRight(model.BaseURL, "/")
timeout := time.Duration(model.TimeoutSeconds) * time.Second
client := &http.Client{Timeout: timeout}
method := strings.ToUpper(strings.TrimSpace(model.HttpMethod))
// 3构建 HTTP 请求
var req *http.Request
switch method {
case http.MethodGet:
q, err := util.BodyToQuery(body)
if err != nil {
return nil, err
}
if len(q) > 0 {
if strings.Contains(baseURL, "?") {
baseURL = baseURL + "&" + q.Encode()
} else {
baseURL = baseURL + "?" + q.Encode()
}
}
req, err = http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil)
default:
bodyBytes, err := json.Marshal(body)
if err != nil {
return nil, err
}
req, err = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes))
}
// 4注入请求头先模型静态配置再动态 modelKey后者可覆盖前者
for hk, hv := range util.ParseHeadMsgHeaders(model.HeadMsg) {
req.Header.Set(hk, hv)
}
if modelKey != "" {
req.Header.Set("Authorization", "Bearer "+modelKey)
}
if method != http.MethodGet {
req.Header.Set("Content-Type", "application/json")
}
// 5发送请求
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
// 6读取响应体
b, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
// 7检查 HTTP 状态码
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
msg := string(b)
return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
}
return b, nil
}
// // InvokeModel 调用模型服务,返回二进制结果
//
// func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelKey string) ([]byte, error) {
// if m == nil || m.BaseURL == "" {
// return nil, fmt.Errorf("模型配置不完整")
// }
// // 请求参数映射
// mappedPayload, err := mapRequestPayload(m.RequestMapping, payload)
// if err != nil {
// return nil, fmt.Errorf("请求参数映射失败: %w", err)
// }
// // 合并请求头
// headers := util.ForwardHeaders(ctx)
// for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) {
// headers[hk] = hv
// }
// for hk, hv := range parseHeadMsgHeaders(modelKey) {
// headers[hk] = hv
// }
//
// // 设置超时
// timeout := time.Duration(m.TimeoutSeconds) * time.Second
// if timeout <= 0 {
// timeout = 600 * time.Second
// }
// ctx, cancel := context.WithTimeout(ctx, timeout)
// defer cancel()
//
// invokeUrl := strings.TrimRight(m.BaseURL, "/")
// method := strings.ToUpper(strings.TrimSpace(m.HttpMethod))
// if method == "" {
// method = http.MethodPost
// }
//
// var respBytes []byte
//
// switch method {
// case http.MethodGet:
// err = commonHttp.Get(ctx, invokeUrl, headers, &respBytes, mappedPayload)
// default:
// err = commonHttp.Post(ctx, invokeUrl, headers, &respBytes, mappedPayload)
// }
// if err != nil {
// return nil, err
// }
// // 响应参数映射
// mappedResponse, err := mapResponsePayload(m.ResponseMapping, respBytes)
// if err != nil {
// g.Log().Warningf(ctx, "响应参数映射失败: %v返回原始数据", err)
// return respBytes, nil
// }
// return mappedResponse, nil
// }
// uploadOSS 从临时文件上传 OSS
func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.AsynchTask) (*gateway.UploadFileResponse, error) {
data, err := os.ReadFile(t.TmpFile)
if err != nil {
return nil, fmt.Errorf("读取临时文件失败: %w", err)
}
_, ext := util.DetectFileType(data)
return gateway.UploadByTask(ctx, data, ext)
}
// failTask 任务失败统一处理:更新数据库 + 释放排队 + 回调
func (w *asyncWorker) failTask(ctx context.Context, t *entity.AsynchTask, startTime time.Time, errMsg string) {
t.State = 3
t.ErrorMsg = errMsg
t.DurationSeconds = int64(time.Since(startTime).Seconds())
_ = dao.Task.UpdateFailedGlobal(ctx, t)
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
}
// rollbackToPending 恢复任务状态为 PENDING
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
return dao.Task.RollbackToPendingGlobal(ctx, id)
}

View File

@@ -1,266 +0,0 @@
package service
import (
"context"
"errors"
"fmt"
"time"
"model-gateway/dao"
"model-gateway/model/dto"
"model-gateway/model/entity"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gtime"
"github.com/google/uuid"
)
var Task = &taskService{}
type taskService struct{}
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
fmt.Printf("打印请求:%+v", req)
startAt := time.Now()
// 固化 token/user 等信息
ctx = asyncCtx(ctx)
// 1) 检查模型配置
m, err := dao.Model.GetByModelName(ctx, req.ModelName)
if err != nil {
return nil, err
}
if m == nil || (m.Enabled != nil && *m.Enabled != 1) {
return nil, errors.New("模型不存在或未启用")
}
taskID := uuid.NewString()
// 2) 排队上限严格控制Redis 原子闸门)
limit := GetRuntimeQueueLimit(ctx, req.ModelName, m.QueueLimit)
if limit > 0 {
ok, err := AcquireQueueSlot(ctx, req.ModelName, taskID, limit, m.ExpectedSeconds)
if err != nil {
return nil, err
}
if !ok {
return nil, errors.New("任务排队已满,请稍后再试")
}
}
// 将调用模型的 payload 与透传头信息一起存入 request_payload供后台 worker 使用
storedPayload := map[string]any{
"payload": req.RequestPayload,
"headers": forwardHeaders(ctx),
}
t := &entity.AsynchTask{
ModelName: req.ModelName,
TaskID: taskID,
State: 0,
BizName: req.BizName,
CallbackURL: req.CallbackUrl,
ModelKey: m.ApiKey,
InputRef: req.InputRef,
RequestPayload: storedPayload,
EpicycleId: req.EpicycleId,
}
_, err = dao.Task.Insert(ctx, t)
if err != nil {
// 入库失败:回滚闸门占位
ReleaseQueueSlot(ctx, req.ModelName, taskID)
return nil, err
}
// 3) 写操作日志(尽量不影响主流程,失败忽略)
ip := ""
ua := ""
apiPath := "/task/createTask"
httpMethod := "POST"
if r := g.RequestFromCtx(ctx); r != nil {
ip = r.GetClientIp()
ua = r.UserAgent()
apiPath = r.URL.Path
httpMethod = r.Method
}
_, _ = dao.OpLog.Insert(ctx, &entity.LogsModelOp{
IP: ip,
UserAgent: ua,
APIPath: apiPath,
HttpMethod: httpMethod,
BizName: req.BizName,
ModelName: req.ModelName,
TaskID: taskID,
OpType: "createTask",
Success: 1,
ErrorMsg: "",
CostMs: time.Since(startAt).Milliseconds(),
RequestPayload: storedPayload,
ResponsePayload: gdb.Map{
"taskId": taskID,
},
})
// 4) 创建成功后立即异步尝试执行当前任务,并仅在任务仍处于 pending(state=0) 时做定向轮询。
// 一旦任务进入 running/success/failed/downloaded就停止轮询避免一直空转。
go s.pollAndRunUntilPicked(context.WithoutCancel(ctx), taskID, req.EpicycleId)
return &dto.CreateTaskRes{TaskID: taskID}, nil
}
// pollAndRunUntilPicked 用于 createTask 创建后的“轻量级定向轮询”:
// - 目标:尽快把刚创建的任务拉起来执行
// - 只在任务仍为 pending(state=0) 时继续尝试抢占
// - 一旦任务进入 running(1) / success(2) / failed(3) / downloaded(4),立即停止
// - 这样不会无限轮询runWork 仍负责处理积压队列和未处理到的任务
func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string, epicycleId int64) {
if taskID == "" {
return
}
interval := g.Cfg().MustGet(ctx, "asynch.worker.intervalSeconds").Int()
if interval <= 0 {
interval = 5
}
g.Log().Infof(ctx, "[task-auto-run][start] taskId=%s interval=%ds", taskID, interval)
ticker := time.NewTicker(time.Duration(interval) * time.Second)
defer ticker.Stop()
tryRun := func() bool {
t, err := dao.Task.GetByTaskID(ctx, taskID)
if err != nil {
g.Log().Warningf(ctx, "[task-auto-run][stop] taskId=%s reason=query_failed err=%v", taskID, err)
return true
}
if t == nil {
g.Log().Warningf(ctx, "[task-auto-run][stop] taskId=%s reason=task_not_found", taskID)
return true
}
switch t.State {
case 0:
if err := AsyncWorker.RunByTaskID(ctx, taskID, epicycleId); err != nil {
g.Log().Warningf(ctx, "[task-auto-run][retry] taskId=%s state=0 err=%v", taskID, err)
} else {
g.Log().Infof(ctx, "[task-auto-run][triggered] taskId=%s state=0", taskID)
}
return false
case 1:
g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=running", taskID)
return true
case 2, 3, 4:
g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=terminal state=%d", taskID, t.State)
return true
default:
g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=unknown_state state=%d", taskID, t.State)
return true
}
}
// 先立即尝试一次
if stop := tryRun(); stop {
return
}
for {
select {
case <-ctx.Done():
g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=context_done", taskID)
return
case <-ticker.C:
if stop := tryRun(); stop {
return
}
}
}
}
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
t, err := dao.Task.GetByTaskID(ctx, taskID)
if err != nil {
return nil, err
}
if t == nil {
return nil, errors.New("任务不存在")
}
return &dto.GetTaskResultRes{
OssFile: t.OssFile,
State: t.State,
}, nil
}
// GetBatch 批量查询任务;将成功(state=2)的任务更新为已下载(state=4),并写入过期时间
func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
if req == nil || len(req.TaskIDs) == 0 {
return &dto.GetTaskBatchRes{List: []dto.GetTaskBatchItem{}}, nil
}
// 1) 先查当前租户下的任务列表
list, err := dao.Task.ListByTaskIDs(ctx, req.TaskIDs)
if err != nil {
return nil, err
}
// 2) 对成功(state=2)的任务:标记为已下载(state=4)并写入 expire_at
now := time.Now()
for _, t := range list {
if t == nil {
continue
}
if t.State != 2 {
continue
}
// 按模型配置决定保留时间
m, err := dao.Model.GetByModelName(ctx, t.ModelName)
if err != nil {
return nil, err
}
retainSeconds := 86400
if m != nil && m.AutoCleanSeconds > 0 {
retainSeconds = m.AutoCleanSeconds
}
expireAt := gtime.New(now.Add(time.Duration(retainSeconds) * time.Second))
_ = dao.Task.MarkDownloadedByID(ctx, t.Id, expireAt)
// 为了本次返回一致性,内存里也更新
t.State = 4
t.ExpireAt = expireAt
}
// 3) 组装返回
items := make([]dto.GetTaskBatchItem, 0, len(list))
for _, t := range list {
if t == nil {
continue
}
items = append(items, dto.GetTaskBatchItem{
TaskID: t.TaskID,
State: t.State,
OssFile: t.OssFile,
})
}
return &dto.GetTaskBatchRes{List: items}, nil
}
func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
pageNum, pageSize := 1, 10
if req != nil {
if req.PageNum > 0 {
pageNum = req.PageNum
}
if req.PageSize > 0 {
pageSize = req.PageSize
}
}
modelName := ""
taskID := ""
var state *int
if req != nil {
modelName = req.ModelName
taskID = req.TaskID
state = req.State
}
list, total, err := dao.Task.List(ctx, pageNum, pageSize, modelName, taskID, state)
if err != nil {
return nil, err
}
return &dto.ListTaskRes{List: list, Total: total}, nil
}

View File

@@ -1,38 +0,0 @@
package service
import (
"fmt"
"os"
"path/filepath"
)
// saveTmpResult 将模型输出写入临时文件,用于 OSS 上传失败后的“仅重试 OSS”。
func saveTmpResult(taskID string, data []byte, ext string) (string, error) {
dir := filepath.Join(os.TempDir(), "model-asynch")
if err := os.MkdirAll(dir, 0o755); err != nil {
return "", err
}
if ext == "" {
ext = ".bin"
}
if ext[0] != '.' {
ext = "." + ext
}
path := filepath.Join(dir, fmt.Sprintf("%s%s", taskID, ext))
if err := os.WriteFile(path, data, 0o644); err != nil {
return "", err
}
return path, nil
}
func loadTmpResult(path string) ([]byte, error) {
return os.ReadFile(path)
}
func deleteTmpResult(path string) {
if path == "" {
return
}
_ = os.Remove(path)
}

View File

@@ -1,113 +0,0 @@
package service
import (
"encoding/json"
"strings"
"github.com/gogf/gf/v2/container/gvar"
)
func normalizeFormValue(v any) any {
// 目标:对外永远返回 JSON 数组/对象,而不是字符串。
if v == nil {
return []any{}
}
switch t := v.(type) {
case string:
s := strings.TrimSpace(t)
if s == "" {
return []any{}
}
return normalizeFormValueFromJSONString(s)
case []byte:
if len(t) == 0 {
return []any{}
}
return normalizeFormValueFromJSONBytes(t)
case *gvar.Var:
// goframe 常见的 DB 返回类型
if t == nil {
return []any{}
}
b := t.Bytes()
if len(b) > 0 {
return normalizeFormValueFromJSONBytes(b)
}
s := strings.TrimSpace(t.String())
if s == "" {
return []any{}
}
return normalizeFormValueFromJSONString(s)
default:
// 尝试兼容其他“像 JSON 的值类型”(例如实现了 Bytes/String 的包装类型)
if vb, ok := v.(interface{ Bytes() []byte }); ok {
if b := vb.Bytes(); len(b) > 0 {
return normalizeFormValueFromJSONBytes(b)
}
}
if vs, ok := v.(interface{ String() string }); ok {
if s := strings.TrimSpace(vs.String()); s != "" {
return normalizeFormValueFromJSONString(s)
}
}
// 已经是 []any / map[string]any 等结构
return v
}
}
// 兼容“JSONB 里存了 JSON 字符串”的历史数据:
// 例如 form_json = '"[]"' 或 '"[{...}]"'(外层是字符串,内层才是数组/对象)
func normalizeFormValueFromJSONString(s string) any {
var out any
if err := json.Unmarshal([]byte(s), &out); err != nil || out == nil {
return []any{}
}
// 如果解出来还是 string且看起来是 JSON再解一层
if inner, ok := out.(string); ok {
inner = strings.TrimSpace(inner)
if inner == "" {
return []any{}
}
if strings.HasPrefix(inner, "[") || strings.HasPrefix(inner, "{") {
var out2 any
if err := json.Unmarshal([]byte(inner), &out2); err == nil && out2 != nil {
return out2
}
}
return []any{}
}
return out
}
func normalizeFormValueFromJSONBytes(b []byte) any {
var out any
if err := json.Unmarshal(b, &out); err != nil || out == nil {
return []any{}
}
// bytes 解出来也可能是 string同上
if inner, ok := out.(string); ok {
return normalizeFormValueFromJSONString(inner)
}
return out
}
func ParseJSONField(field any) any {
var v *gvar.Var
switch val := field.(type) {
case *gvar.Var:
v = val
default:
return field
}
if v == nil || v.IsNil() || v.IsEmpty() {
return nil
}
str := v.String()
var result any
if json.Unmarshal([]byte(str), &result) == nil {
return result
}
return str
}

View File

@@ -1,246 +0,0 @@
package service
import (
"context"
"fmt"
"strings"
"time"
"unicode/utf8"
"model-gateway/dao"
"model-gateway/model/entity"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/grpool"
"github.com/tidwall/gjson"
)
var AsyncWorker = &asyncWorker{}
type asyncWorker struct {
}
// RunOnce 由上层定时任务触发:一次性抢占并处理一批任务
// - batchSize: 本次抢占数量
// - goroutines: 本次并发数(协程池大小)
func (w *asyncWorker) RunOnce(ctx context.Context, batchSize, goroutines int) (claimed int, err error) {
if batchSize <= 0 {
batchSize = 10
}
if goroutines <= 0 {
goroutines = 1
}
tasks, err := dao.Task.ClaimPendingGlobal(ctx, batchSize)
if err != nil {
return 0, err
}
if len(tasks) == 0 {
return 0, nil
}
pool := grpool.New(goroutines)
defer pool.Close()
claimed = len(tasks)
done := make(chan struct{}, claimed)
for _, t := range tasks {
task := t
_ = pool.AddWithRecover(ctx, func(ctx context.Context) {
w.handleOne(ctx, task, 0)
done <- struct{}{}
}, func(ctx context.Context, e error) {
if e != nil {
_ = dao.Task.UpdateFailedGlobal(ctx, task.Id, fmt.Sprintf("worker panic: %v", e))
ReleaseQueueSlot(ctx, task.ModelName, task.TaskID)
}
done <- struct{}{}
})
}
for i := 0; i < claimed; i++ {
<-done
}
return claimed, nil
}
// RunByTaskID 创建任务后立即异步尝试执行当前任务:
// - 只定向抢占当前 taskId 对应的 pending 任务
// - 若任务已被其它 worker 抢走/已不在 pending则直接返回
func (w *asyncWorker) RunByTaskID(ctx context.Context, taskID string, epicycleId int64) error {
task, err := dao.Task.ClaimPendingByTaskIDGlobal(ctx, taskID)
if err != nil {
return err
}
if task == nil {
return nil
}
w.handleOne(ctx, task, epicycleId)
return nil
}
func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicycleId int64) {
// 从任务入库的 request_payload 里恢复 payload + headers
payload, headers := parseStoredPayload(t.RequestPayload)
if len(headers) > 0 {
ctx = setTaskHeadersToCtx(ctx, headers)
}
// 1) 拉取模型配置
m, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
if err != nil {
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
// ============ 失败回调 ============
t.State = 3
t.ErrorMsg = err.Error()
go triggerCallback(context.WithoutCancel(ctx), t)
// ================================
return
}
if m == nil || (m.Enabled != nil && *m.Enabled != 1) {
errMsg := "模型不存在或未启用"
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, errMsg)
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
// ============ 失败回调 ============
t.State = 3
t.ErrorMsg = errMsg
go triggerCallback(context.WithoutCancel(ctx), t)
// ================================
return
}
// 2) 分布式并发限制
semKey := fmt.Sprintf("asynch:sem:%s", t.ModelName)
leaseSeconds := int64(3600)
maxC := GetRuntimeMaxConcurrency(ctx, t.ModelName, m.MaxConcurrency)
acquired, err := acquireSemaphore(ctx, semKey, maxC, leaseSeconds)
if err != nil {
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
// ============ 失败回调 ============
t.State = 3
t.ErrorMsg = err.Error()
go triggerCallback(context.WithoutCancel(ctx), t)
// ================================
return
}
if !acquired {
// 并发满了:放回排队,不回调(不是失败)
_ = w.rollbackToPending(ctx, t.Id)
return
}
defer func() {
_ = releaseSemaphore(ctx, semKey)
}()
// 3) 调用模型服务
if payload == nil {
payload = map[string]any{
"taskId": t.TaskID,
"inputRef": t.InputRef,
}
}
var (
data []byte
contentType string
ext string
textResult string
)
// phase=1 表示模型已成功但 OSS 上传失败:优先从临时文件加载
if t.Phase == 1 && strings.TrimSpace(t.TmpFile) != "" {
data, err = loadTmpResult(t.TmpFile)
if err == nil && len(data) > 0 {
contentType, ext = DetectFileType(data)
} else {
data = nil
}
}
if data == nil {
// 统计
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(t.TenantId), t.Creator, t.ModelName)
// 核心调用
data, err = InvokeModel(ctx, m, payload, t.ModelKey)
if err != nil {
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
// ============ 失败回调 ============
t.State = 3
t.ErrorMsg = err.Error()
go triggerCallback(context.WithoutCancel(ctx), t)
// ================================
return
}
contentType, ext = DetectFileType(data)
if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") {
textResult = string(data)
}
tmpPath, err := saveTmpResult(t.TaskID, data, ext)
if err == nil && tmpPath != "" {
t.TmpFile = tmpPath
t.Phase = 1
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, t.Id, tmpPath)
}
}
// 4) 存储 OSS
ossURL, err := Storage.UploadByTask(ctx, t, data, ext, contentType)
if err != nil {
// OSS 阶段失败:保留临时文件,下一轮仅重试 OSS
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, t.Id, err.Error())
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
// ============ OSS失败不回调还会重试 ============
// 注意OSS失败保留临时文件下次重试所以这里不触发最终回调
// 如果已经重试多次还没成功,需要在任务超时或超过最大重试次数时才回调失败
return
}
// 5) 更新任务状态成功
fileType := strings.TrimPrefix(ext, ".")
if fileType == "" {
fileType = contentType
}
if err := dao.Task.UpdateSuccessGlobal(
ctx,
t.Id,
ossURL,
fileType,
textResult,
int64(len(data)),
nil,
GetExpendTokens(m.TokenMapping, textResult),
); err != nil {
g.Log().Errorf(ctx, "[worker] update success failed: %v", err)
return
}
// 成功/失败均不再占用 queue_limit
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
// 6) 成功回调
t.State = 2
t.OssFile = ossURL
t.FileType = fileType
t.TextResult = textResult
g.Log().Infof(ctx, "[CALLBACK][DISPATCH] taskId=%s bizName=%s callbackUrl=%s", t.TaskID, t.BizName, t.CallbackURL)
go triggerCallback(context.WithoutCancel(ctx), t)
// ============ 如果有 epicycleId也触发业务回调 ============
if epicycleId != 0 {
go triggerPromptsCallback(context.WithoutCancel(ctx), t, epicycleId)
}
// 成功后清理临时文件
deleteTmpResult(t.TmpFile)
}
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
return dao.Task.RollbackToPendingGlobal(ctx, id)
}
// GetExpendTokens 根据映射路径从 textResult 中提取消耗 token 值
func GetExpendTokens(tokenMapping string, textResult string) int {
value := gjson.Get(textResult, tokenMapping)
if value.Exists() {
return int(value.Int())
} else {
return len(textResult)
}
}

Binary file not shown.

Binary file not shown.

View File

@@ -1 +0,0 @@
Asia/Shanghai

View File

@@ -8,41 +8,57 @@
-- 1) asynch_models
-- =========================
CREATE TABLE IF NOT EXISTS asynch_models (
-- 基础字段
id BIGINT PRIMARY KEY, -- 主键ID(非自增)
tenant_id BIGINT NOT NULL DEFAULT 0, -- 租户ID
creator VARCHAR(64) NOT NULL, -- 创建人
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- 创建时间
updater VARCHAR(64) NOT NULL, -- 更新人
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- 更新时间
deleted_at TIMESTAMP(6), -- 删除时间(软删)
-- 业务字段
model_name VARCHAR(128) NOT NULL, -- 模型名称
model_type SMALLINT NOT NULL DEFAULT 0, -- 模型类型
base_url VARCHAR(256) NOT NULL, -- 模型地址
http_method VARCHAR(8) NOT NULL DEFAULT 'POST', -- 请求方式 GET/POST
head_msg VARCHAR(1024) DEFAULT '', -- 请求头绑定(支持多个,逗号分隔)示例 X-API:xxx,operation:true
is_private SMALLINT NOT NULL DEFAULT 0, -- 是否私有化 0-私有 1-公共
enabled SMALLINT NOT NULL DEFAULT 1, -- 是否启用 0停用 1-启用
is_chat_model SMALLINT NOT NULL DEFAULT 0, -- 是否为对话模型 0-否 1-是
is_owner SMALLINT NOT NULL DEFAULT 99, -- 1=当前用户创建的0=超级管理员的
api_key VARCHAR(256) NOT NULL DEFAULT '', -- 调用凭证,密钥
prompt TEXT NOT NULL DEFAULT '', -- 提示词内容(文本)
form_json JSONB NOT NULL DEFAULT '{}'::jsonb, -- 表单结构(用于前端渲染)
request_mapping JSONB NOT NULL DEFAULT '{}'::jsonb -- 请求映射
response_mapping JSONB NOT NULL DEFAULT '{}'::jsonb, -- 返回映射
response_body JSONB NOT NULL DEFAULT '{}'::jsonb, -- 返回主体
max_concurrency INT NOT NULL DEFAULT 10, -- 单模型最大并发
queue_limit INT NOT NULL DEFAULT 1000, -- 排队上限(近似控制)
timeout_seconds INT NOT NULL DEFAULT 600, -- 调用模型服务超时(秒)
expected_seconds INT NOT NULL DEFAULT 600, -- 模型预计执行时间(秒)
retry_times SMALLINT NOT NULL DEFAULT 3, -- 失败重试次数
retry_queue_max_seconds INT NOT NULL DEFAULT 600, -- 失败重试最大排队时间(秒 0=插队到队首;>0=排队超过该时间后插队,否则仍到队尾)
auto_clean_seconds INT NOT NULL DEFAULT 86400, -- 已下载(state=4 后的保留时间(秒),到期清理)
remark TEXT DEFAULT '' -- 备注
token_mapping VARCHAR(128) NOT NULL DEFAULT ''; -- token 映射
);
-- ========== 基础字段 ==========
id BIGINT PRIMARY KEY,
tenant_id BIGINT NOT NULL DEFAULT 0,
creator VARCHAR(64) NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updater VARCHAR(64) NOT NULL,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
deleted_at TIMESTAMP(6),
-- ========== 模型标识 ==========
model_name VARCHAR(128) NOT NULL,
model_type SMALLINT NOT NULL DEFAULT 0,
operator_name VARCHAR(64) NOT NULL DEFAULT '',
-- ========== 请求配置 ==========
base_url VARCHAR(256) NOT NULL,
http_method VARCHAR(8) NOT NULL DEFAULT 'POST',
head_msg JSONB NOT NULL DEFAULT '{}'::jsonb,
api_key VARCHAR(256) NOT NULL DEFAULT '',
-- ========== 状态开关 ==========
is_private SMALLINT NOT NULL DEFAULT 0,
enabled SMALLINT NOT NULL DEFAULT 1,
is_chat_model SMALLINT NOT NULL DEFAULT 0,
is_async SMALLINT NOT NULL DEFAULT 0,
is_stream SMALLINT NOT NULL DEFAULT 0,
is_owner SMALLINT NOT NULL DEFAULT 99,
-- ========== 配置相关 ==========
form_json JSONB NOT NULL DEFAULT '{}'::jsonb,
request_mapping JSONB NOT NULL DEFAULT '{}'::jsonb,
response_mapping JSONB NOT NULL DEFAULT '{}'::jsonb,
response_body JSONB NOT NULL DEFAULT '{}'::jsonb,
token_config JSONB NOT NULL DEFAULT '{}'::jsonb,
extend_mapping JSONB NOT NULL DEFAULT '{}'::jsonb,
query_config JSONB NOT NULL DEFAULT '{}'::jsonb,
stream_config JSONB NOT NULL DEFAULT '{}'::jsonb,
first_frame VARCHAR(128) NOT NULL DEFAULT '',
last_frame VARCHAR(128) NOT NULL DEFAULT '',
-- ========== 限制与重试 ==========
max_concurrency INT NOT NULL DEFAULT 10,
timeout_seconds INT NOT NULL DEFAULT 600,
retry_times SMALLINT NOT NULL DEFAULT 3,
auto_clean_seconds INT NOT NULL DEFAULT 86400,
-- ========== 其他 ==========
response_token_field VARCHAR(128) NOT NULL DEFAULT '',
);
-- ========== 索引 ==========
CREATE UNIQUE INDEX IF NOT EXISTS uk_asynch_models_tenant_creator_chat ON asynch_models(tenant_id, creator) WHERE is_chat_model = 1 AND deleted_at IS NULL;
CREATE UNIQUE INDEX IF NOT EXISTS uk_asynch_models_tenant_model_name ON asynch_models(tenant_id, creator, model_name);
CREATE INDEX IF NOT EXISTS idx_asynch_models_tenant_id ON asynch_models(tenant_id);
@@ -51,7 +67,9 @@ CREATE INDEX IF NOT EXISTS idx_asynch_models_model_type ON asynch_models(model_t
CREATE INDEX IF NOT EXISTS idx_asynch_models_enabled ON asynch_models(enabled);
CREATE INDEX IF NOT EXISTS idx_asynch_models_deleted_at ON asynch_models(deleted_at);
-- ========== 注释 ==========
COMMENT ON TABLE asynch_models IS '模型配置表';
COMMENT ON COLUMN asynch_models.id IS '主键ID(非自增)';
COMMENT ON COLUMN asynch_models.tenant_id IS '租户ID';
COMMENT ON COLUMN asynch_models.creator IS '创建人';
@@ -62,29 +80,32 @@ COMMENT ON COLUMN asynch_models.deleted_at IS '删除时间(软删)';
COMMENT ON COLUMN asynch_models.model_name IS '模型名称';
COMMENT ON COLUMN asynch_models.model_type IS '模型类型';
COMMENT ON COLUMN asynch_models.operator_name IS '运营商名称';
COMMENT ON COLUMN asynch_models.base_url IS '模型地址';
COMMENT ON COLUMN asynch_models.http_method IS '请求方式 GET/POST';
COMMENT ON COLUMN asynch_models.head_msg IS '请求头绑定(支持多个,逗号分隔)示例 X-API:xxx,operation:true';
COMMENT ON COLUMN asynch_models.is_private IS '是否私有化 0-私有 1-公共';
COMMENT ON COLUMN asynch_models.enabled IS '是否启用 0停用 1-启用';
COMMENT ON COLUMN asynch_models.is_chat_model IS '是否为对话模型 0-否 1-';
COMMENT ON COLUMN asynch_models.is_owner IS '1=当前用户创建的0=超级管理员的';
COMMENT ON COLUMN asynch_models.api_key IS '调用凭证,密钥';
COMMENT ON COLUMN asynch_models.prompt IS '提示词内容(文本)';
COMMENT ON COLUMN asynch_models.form_json IS '表单结构(用于前端渲染,也用于后端校验)';
COMMENT ON COLUMN asynch_models.head_msg IS '请求头信息';
COMMENT ON COLUMN asynch_models.api_key IS '调用凭证/密钥';
COMMENT ON COLUMN asynch_models.is_private IS '是否私有化0-私有 1-公共';
COMMENT ON COLUMN asynch_models.enabled IS '是否启用0-停用 1-启用';
COMMENT ON COLUMN asynch_models.is_chat_model IS '是否为对话模型0-否 1-是';
COMMENT ON COLUMN asynch_models.is_async IS '是否异步0-同步 1-异步';
COMMENT ON COLUMN asynch_models.is_stream IS '是否流式0-非流式 1-流式';
COMMENT ON COLUMN asynch_models.is_owner IS '1=当前用户创建 0=超级管理员';
COMMENT ON COLUMN asynch_models.form_json IS '动态表单结构';
COMMENT ON COLUMN asynch_models.request_mapping IS '请求映射';
COMMENT ON COLUMN asynch_models.response_mapping IS '返回映射';
COMMENT ON COLUMN asynch_models.response_body IS '返回主体';
COMMENT ON COLUMN asynch_models.max_concurrency IS '单模型最大并发';
COMMENT ON COLUMN asynch_models.queue_limit IS '排队上限(近似控制)';
COMMENT ON COLUMN asynch_models.timeout_seconds IS '调用模型服务超时(秒)';
COMMENT ON COLUMN asynch_models.expected_seconds IS '模型预计执行时间(秒)';
COMMENT ON COLUMN asynch_models.token_config IS 'Token计算配置';
COMMENT ON COLUMN asynch_models.extend_mapping IS '附加映射';
COMMENT ON COLUMN asynch_models.query_config IS '查询/回调配置';
COMMENT ON COLUMN asynch_models.stream_config IS '流式输出配置';
COMMENT ON COLUMN asynch_models.first_frame IS '首帧图片参数';
COMMENT ON COLUMN asynch_models.last_frame IS '尾帧图片参数';
COMMENT ON COLUMN asynch_models.max_concurrency IS '最大并发数';
COMMENT ON COLUMN asynch_models.timeout_seconds IS '调用模型超时(秒)';
COMMENT ON COLUMN asynch_models.retry_times IS '失败重试次数';
COMMENT ON COLUMN asynch_models.retry_queue_max_seconds IS '失败重试最大排队时间(秒 0=插队到队首;>0=排队超过该时间后插队,否则仍到队尾)';
COMMENT ON COLUMN asynch_models.auto_clean_seconds IS '已下载(state=4 后的保留时间(秒),到期清理)';
COMMENT ON COLUMN asynch_models.remark IS '备注';
COMMENT ON COLUMN asynch_models.token_mapping IS 'token映射';
COMMENT ON COLUMN asynch_models.auto_clean_seconds IS '任务完成后自动清理时间(秒)';
COMMENT ON COLUMN asynch_models.response_token_field IS '响应中消耗token的字段映射';
-- =========================