Compare commits
34 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 196d2069ac | |||
| 7596cbde09 | |||
| 7ec18926e3 | |||
| a6b32bfeb3 | |||
| 2dc88ae587 | |||
| e906248b0a | |||
| e5781aca06 | |||
| 0cf8948cd2 | |||
| 96e8bdfe62 | |||
| 26de41d04e | |||
| 0bee3685fb | |||
| 9049e0d2e8 | |||
| aae46a4f29 | |||
| bcfcc7ed47 | |||
| 2c7838807b | |||
| 52124385a1 | |||
| c7e9eb889b | |||
| 558fd49ec1 | |||
| d409b84b58 | |||
| e487b4bb5e | |||
| a28fcbaee9 | |||
| 5416e7a983 | |||
| 0e2ac286e9 | |||
| a88dc84d99 | |||
| 4d2d4fd93d | |||
| 7129bd2de7 | |||
| 09474eb997 | |||
| 4946220185 | |||
| b6cdb8ff1d | |||
| 4626d819b5 | |||
| 170568e03e | |||
| a080a5536d | |||
| 142fea1e91 | |||
| a585233c4d |
30
Dockerfile
30
Dockerfile
@@ -1,43 +1,23 @@
|
||||
# 多阶段构建 - 第一阶段:编译(使用已安装的镜像)
|
||||
FROM golang:1.26-alpine3.23 AS builder
|
||||
# 阶段1: 构建
|
||||
FROM golang:alpine AS builder
|
||||
|
||||
RUN apk add --no-cache git ca-certificates tzdata
|
||||
|
||||
ENV TZ=Asia/Shanghai
|
||||
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
|
||||
|
||||
ENV GO111MODULE=on
|
||||
ENV GOPROXY=https://goproxy.cn,direct
|
||||
ENV CGO_ENABLED=0
|
||||
ENV GOTOOLCHAIN=auto
|
||||
ENV GOPRIVATE=gitea.com/red-future/common
|
||||
|
||||
# 配置git使用私有Gitea仓库(带Token认证)
|
||||
RUN git config --global url."http://x-token-auth:619679cd366aefea3a50f0622d842a41f2209e08595767bba49c3836ef57d415@116.204.74.41:3000/red-future/common.git".insteadOf "https://gitea.com/red-future/common.git" && \
|
||||
git config --global credential.helper store
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
# 复制父目录的 common 模块(因为 go.mod 中使用了本地 replace)
|
||||
#COPY ../common /build/common
|
||||
COPY . .
|
||||
|
||||
RUN go mod download && go mod tidy
|
||||
|
||||
RUN go build -ldflags="-s -w" -o main ./main.go
|
||||
|
||||
# 第二阶段:运行
|
||||
FROM alpine:3.23
|
||||
|
||||
ENV TIME_ZONE=Asia/Shanghai
|
||||
RUN apk add --no-cache ca-certificates tzdata && \
|
||||
ln -sf /usr/share/zoneinfo/$TIME_ZONE /etc/localtime
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# 复制编译好的二进制文件
|
||||
COPY --from=builder /build/main .
|
||||
COPY --from=builder /build/config.yml ./
|
||||
|
||||
# 创建日志目录
|
||||
RUN mkdir -p /logs /app/resource/log/run /app/resource/log/server
|
||||
|
||||
EXPOSE 3004
|
||||
|
||||
|
||||
10
common/util/convert.go
Normal file
10
common/util/convert.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package util
|
||||
|
||||
import "github.com/gogf/gf/v2/util/gconv"
|
||||
|
||||
// ConvertTo 转换为指定类型
|
||||
func ConvertTo[T any](v interface{}) *T {
|
||||
var t T
|
||||
_ = gconv.Struct(v, &t)
|
||||
return &t
|
||||
}
|
||||
115
common/util/files.go
Normal file
115
common/util/files.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// DetectFileType sniffs data with http.DetectContentType and returns the bare
// MIME type (any parameters such as "; charset=utf-8" are removed) together
// with a file extension that includes the leading dot.
//
// A handful of well-known types map to their conventional extension. For any
// other type the MIME subtype is used as the extension, e.g. "image/gif"
// yields ".gif". Empty input and content that can only be identified as
// "application/octet-stream" both yield an empty extension, so the caller can
// apply its own default instead of receiving ".octet-stream".
func DetectFileType(data []byte) (contentType string, ext string) {
	const unknownType = "application/octet-stream"

	if len(data) == 0 {
		return unknownType, ""
	}

	ct := http.DetectContentType(data)
	// The sniffer may append parameters: "text/plain; charset=utf-8".
	if base, _, found := strings.Cut(ct, ";"); found {
		ct = strings.TrimSpace(base)
	}

	switch ct {
	case unknownType:
		// Unidentified binary: keep this consistent with the empty-input case.
		return ct, ""
	case "audio/mpeg":
		return ct, ".mp3"
	case "audio/wave", "audio/wav", "audio/x-wav":
		return ct, ".wav"
	case "video/mp4":
		return ct, ".mp4"
	case "image/png":
		return ct, ".png"
	case "image/jpeg":
		return ct, ".jpg"
	case "application/pdf":
		return ct, ".pdf"
	case "text/plain":
		return ct, ".txt"
	case "application/json":
		return ct, ".json"
	}

	// Fallback: use the MIME subtype as the extension. Parameters were already
	// stripped above, so the subtype needs no further cleaning.
	if _, sub, ok := strings.Cut(ct, "/"); ok && sub != "" && !strings.Contains(sub, "/") {
		return ct, "." + sub
	}
	return ct, ""
}
|
||||
|
||||
// SaveTmpResult writes data to <os.TempDir()>/model-asynch/<taskID><ext> and
// returns the path of the written file. The directory is created if needed.
//
// An empty ext defaults to ".bin"; an ext without a leading dot gets one
// prepended. On any failure the returned path is empty.
func SaveTmpResult(taskID string, data []byte, ext string) (string, error) {
	baseDir := filepath.Join(os.TempDir(), "model-asynch")
	if mkErr := os.MkdirAll(baseDir, 0o755); mkErr != nil {
		return "", mkErr
	}

	// Normalise the extension so it is never empty and always starts with a dot.
	switch {
	case ext == "":
		ext = ".bin"
	case ext[0] != '.':
		ext = "." + ext
	}

	target := filepath.Join(baseDir, fmt.Sprintf("%s%s", taskID, ext))
	if writeErr := os.WriteFile(target, data, 0o644); writeErr != nil {
		return "", writeErr
	}
	return target, nil
}
|
||||
|
||||
// SaveTempFileByType
|
||||
// 根据传入的数据自动判断:
|
||||
// 若是 []byte 且后缀为 .mp3 → 保存二进制音频
|
||||
// 若是任意结构体/map → 自动转 JSON 保存
|
||||
// 返回:新临时文件路径、错误
|
||||
func SaveTempFileByType(taskID string, data any, oldTmpFile string) (string, error) {
|
||||
// 1. 先清理旧临时文件(统一逻辑)
|
||||
if oldTmpFile != "" {
|
||||
_ = os.Remove(oldTmpFile)
|
||||
}
|
||||
|
||||
var tmpPath string
|
||||
var tmpErr error
|
||||
|
||||
// 2. 判断是否是二进制音频([]byte + .mp3)
|
||||
if audioData, ok := data.([]byte); ok {
|
||||
tmpPath, tmpErr = saveTmpResult(taskID, audioData, ".mp3")
|
||||
} else {
|
||||
// 3. 其他类型 → 序列化为 JSON 保存
|
||||
mappedBytes, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(mappedBytes) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
tmpPath, tmpErr = saveTmpResult(taskID, mappedBytes, ".json")
|
||||
}
|
||||
|
||||
if tmpErr != nil || tmpPath == "" {
|
||||
return "", tmpErr
|
||||
}
|
||||
|
||||
return tmpPath, nil
|
||||
}
|
||||
|
||||
// saveTmpResult writes data to <os.TempDir()>/<taskID><ext> with mode 0644.
//
// The computed path is returned together with the os.WriteFile error, so the
// path is non-empty even when the write fails.
func saveTmpResult(taskID string, data []byte, ext string) (string, error) {
	target := filepath.Join(os.TempDir(), taskID+ext)
	return target, os.WriteFile(target, data, 0644)
}
|
||||
79
common/util/headers.go
Normal file
79
common/util/headers.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gitea.redpowerfuture.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// AsyncCtx 固化异步上下文中的 token 和用户信息,避免请求结束后丢失
|
||||
func AsyncCtx(ctx context.Context) context.Context {
|
||||
asyncCtx := context.WithoutCancel(ctx)
|
||||
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
if token := r.Header.Get("Authorization"); token != "" {
|
||||
asyncCtx = context.WithValue(asyncCtx, "token", token)
|
||||
}
|
||||
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
|
||||
asyncCtx = context.WithValue(asyncCtx, "xUserInfo", userInfo)
|
||||
}
|
||||
}
|
||||
|
||||
if user, err := utils.GetUserInfo(ctx); err == nil && user != nil {
|
||||
asyncCtx = context.WithValue(asyncCtx, "user", user)
|
||||
}
|
||||
|
||||
return asyncCtx
|
||||
}
|
||||
|
||||
// ForwardHeaders 透传调用链路的头信息,优先使用 ctx 中的固化值
|
||||
func ForwardHeaders(ctx context.Context) map[string]string {
|
||||
headers := make(map[string]string)
|
||||
SetHeaderFromContext(headers, ctx, "Authorization", "token")
|
||||
SetHeaderFromContext(headers, ctx, "X-User-Info", "xUserInfo")
|
||||
FallbackToRequestHeaders(headers, ctx)
|
||||
return headers
|
||||
}
|
||||
|
||||
// SetHeaderFromContext copies the value stored in ctx under ctxKey into
// headers[headerKey].
//
// Nothing is written when the context value is absent, is not a string, or is
// the empty string.
func SetHeaderFromContext(headers map[string]string, ctx context.Context, headerKey, ctxKey string) {
	stored, isString := ctx.Value(ctxKey).(string)
	if !isString || stored == "" {
		return
	}
	headers[headerKey] = stored
}
|
||||
|
||||
// FallbackToRequestHeaders 从请求头中获取作为兜底
|
||||
func FallbackToRequestHeaders(headers map[string]string, ctx context.Context) {
|
||||
r := g.RequestFromCtx(ctx)
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if headers["Authorization"] == "" {
|
||||
if token := r.Header.Get("Authorization"); token != "" {
|
||||
headers["Authorization"] = token
|
||||
}
|
||||
}
|
||||
|
||||
if headers["X-User-Info"] == "" {
|
||||
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
|
||||
headers["X-User-Info"] = userInfo
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetTaskHeadersToCtx 把任务入库时保存的 header 信息注入 ctx,给 worker 调 OSS 用
|
||||
func SetTaskHeadersToCtx(ctx context.Context, headers map[string]string) context.Context {
|
||||
if headers == nil {
|
||||
return ctx
|
||||
}
|
||||
if v := gconv.String(headers["Authorization"]); v != "" {
|
||||
ctx = context.WithValue(ctx, "token", v)
|
||||
}
|
||||
if v := gconv.String(headers["X-User-Info"]); v != "" {
|
||||
ctx = context.WithValue(ctx, "xUserInfo", v)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
359
common/util/mapping.go
Normal file
359
common/util/mapping.go
Normal file
@@ -0,0 +1,359 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"model-gateway/model/entity"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gitea.redpowerfuture.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
tgjson "github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// ParseAndValidate 解析并校验结果
|
||||
func ParseAndValidate(raw map[string]any, model *entity.AsynchModel) (map[string]any, error) {
|
||||
// 1) 解析 content 字符串为 rounds 数组
|
||||
contentVal, ok := raw[model.ResponseBody]
|
||||
if !ok {
|
||||
return raw, fmt.Errorf("字段 %s 不存在", model.ResponseBody)
|
||||
}
|
||||
contentStr, ok := contentVal.(string)
|
||||
if !ok || strings.TrimSpace(contentStr) == "" {
|
||||
return raw, fmt.Errorf("字段 %s 为空或不是字符串", model.ResponseBody)
|
||||
}
|
||||
var arr []any
|
||||
if err := json.Unmarshal([]byte(contentStr), &arr); err != nil {
|
||||
return raw, fmt.Errorf("JSON解析失败: %w", err)
|
||||
}
|
||||
if len(arr) == 0 {
|
||||
return raw, fmt.Errorf("解析后数组为空")
|
||||
}
|
||||
|
||||
// 2) 校验必填字段
|
||||
if len(model.RequiredFields) > 0 {
|
||||
for i, r := range arr {
|
||||
round, ok := r.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, field := range model.RequiredFields {
|
||||
if gjson.New(round).Get(field).IsNil() {
|
||||
return raw, fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, field)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]any{"total_rounds": len(arr), "rounds": arr}, nil
|
||||
}
|
||||
|
||||
// ParseStructResult 解析结构结果
|
||||
func ParseStructResult(raw map[string]any, responseBody string) map[string]any {
|
||||
contentVal := raw[responseBody]
|
||||
// 是字符串,尝试解析
|
||||
contentStr := gconv.String(contentVal)
|
||||
if contentStr == "" || contentStr == "0" {
|
||||
return map[string]any{
|
||||
"total_rounds": 1,
|
||||
"rounds": []map[string]any{{responseBody: raw}},
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试解析为数组
|
||||
var arr []any
|
||||
if err := json.Unmarshal([]byte(contentStr), &arr); err == nil && len(arr) > 0 {
|
||||
return map[string]any{
|
||||
"total_rounds": 1,
|
||||
"rounds": []map[string]any{{responseBody: arr}},
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试解析为单个对象
|
||||
var parsed any
|
||||
if err := json.Unmarshal([]byte(contentStr), &parsed); err == nil {
|
||||
return map[string]any{
|
||||
"total_rounds": 1,
|
||||
"rounds": []map[string]any{{responseBody: parsed}},
|
||||
}
|
||||
}
|
||||
|
||||
// 兜底:原始字符串作为内容
|
||||
return map[string]any{
|
||||
"total_rounds": 1,
|
||||
"rounds": []map[string]any{{responseBody: contentStr}},
|
||||
}
|
||||
}
|
||||
|
||||
// ValidatePromptResult 校验模型返回结果的 JSON 结构完整性
|
||||
// raw 必须包含 "rounds" 字段,格式为 []map[string]any
|
||||
func ValidatePromptResult(raw map[string]any, model *entity.AsynchModel) error {
|
||||
// 1) 获取 rounds
|
||||
roundsRaw, ok := raw["rounds"]
|
||||
if !ok {
|
||||
return fmt.Errorf("缺少 rounds 字段")
|
||||
}
|
||||
rounds, ok := roundsRaw.([]any)
|
||||
if !ok {
|
||||
return fmt.Errorf("rounds 不是数组")
|
||||
}
|
||||
if len(rounds) == 0 {
|
||||
return fmt.Errorf("rounds 数组为空")
|
||||
}
|
||||
|
||||
// 2) 没有配置必填字段,跳过
|
||||
if len(model.RequiredFields) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 3) 逐条校验
|
||||
for i, r := range rounds {
|
||||
round, ok := r.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, field := range model.RequiredFields {
|
||||
if gjson.New(round).Get(field).IsNil() {
|
||||
return fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, field)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateRequiredFields 校验单个 round 对象的必选字段
|
||||
func validateRequiredFields(round map[string]any, requiredFields []string, prefix string) error {
|
||||
for _, field := range requiredFields {
|
||||
if gjson.New(round).Get(field).IsNil() {
|
||||
return fmt.Errorf("%s 缺少必填字段: %s", prefix, field)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseHeadMsgHeaders 从 head_msg JSON 中提取请求头
|
||||
// head_msg 格式示例:
|
||||
//
|
||||
// {
|
||||
// "Authorization": "Bearer xxx",
|
||||
// "Content-Type": "application/json",
|
||||
// "X-Api-App-Id": "5147401364",
|
||||
// "X-Api-Access-Key": "VCqRX7..."
|
||||
// }
|
||||
func ParseHeadMsgHeaders(headMsg map[string]any) map[string]string {
|
||||
if len(headMsg) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]string, len(headMsg))
|
||||
for k, v := range headMsg {
|
||||
out[k] = gconv.String(v)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// MapResponsePayload 映射模型响应为标准格式
|
||||
func MapResponsePayload(mapping map[string]any, result map[string]any) (map[string]any, error) {
|
||||
if len(mapping) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// 把 result 转成 JSON 字符串,tidwall/gjson 需要字符串输入
|
||||
resultBytes, _ := json.Marshal(result)
|
||||
resultStr := string(resultBytes)
|
||||
|
||||
mapped := make(map[string]any)
|
||||
|
||||
for standardField, modelPath := range mapping {
|
||||
path := gconv.String(modelPath)
|
||||
if path == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
value := tgjson.Get(resultStr, path)
|
||||
if !value.Exists() {
|
||||
continue
|
||||
}
|
||||
// 如果是数组路径(含 #),取 Array;否则取单值
|
||||
if strings.Contains(path, "#") {
|
||||
var arr []any
|
||||
for _, v := range value.Array() {
|
||||
arr = append(arr, v.Value())
|
||||
}
|
||||
mapped[standardField] = arr
|
||||
} else {
|
||||
mapped[standardField] = value.Value()
|
||||
}
|
||||
}
|
||||
|
||||
return mapped, nil
|
||||
}
|
||||
|
||||
// GetModelBody 获取数据库中保存的模型信息
|
||||
func GetModelBody(v map[string]any) map[string]any {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
if p, ok := v["body"]; ok {
|
||||
return gconv.Map(p)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// BodyToQuery 将 body 转为 url.Values
|
||||
func BodyToQuery(payload map[string]any) (url.Values, error) {
|
||||
q := url.Values{}
|
||||
for k, v := range payload {
|
||||
if v == nil {
|
||||
continue
|
||||
}
|
||||
q.Set(k, gconv.String(v))
|
||||
}
|
||||
return q, nil
|
||||
}
|
||||
|
||||
// PullTaskResult 轮询查询异步任务结果直到完成
|
||||
func PullTaskResult(ctx context.Context, body map[string]any, queryConfig map[string]any, headMsg map[string]any) (map[string]any, error) {
|
||||
// 1) 解析配置
|
||||
// 1.1 提取 taskID
|
||||
taskIDPath := gconv.String(queryConfig["task_id"])
|
||||
taskID := gconv.String(gjson.New(body).Get(taskIDPath).Val())
|
||||
if taskID == "" {
|
||||
return nil, fmt.Errorf("无法从路径 %s 提取 taskID", taskIDPath)
|
||||
}
|
||||
g.Log().Infof(ctx, "[PullTaskResult] taskID=%s", taskID)
|
||||
|
||||
// 1.2 请求地址,替换 {id}
|
||||
queryUrl := gconv.String(queryConfig["url"])
|
||||
queryUrl = replaceURLParams(queryUrl, map[string]any{"id": taskID})
|
||||
|
||||
// 1.3 请求方式
|
||||
method := gconv.String(queryConfig["method"])
|
||||
if method == "" {
|
||||
method = "GET"
|
||||
}
|
||||
|
||||
// 1.4 状态判断配置
|
||||
statusPath := gconv.String(queryConfig["status_path"])
|
||||
statusValues, _ := queryConfig["status_values"].(map[string]any)
|
||||
if statusPath == "" {
|
||||
statusPath = "status"
|
||||
}
|
||||
|
||||
// 1.5 轮询间隔
|
||||
interval := gconv.Int(queryConfig["interval_seconds"])
|
||||
if interval <= 0 {
|
||||
interval = 2
|
||||
}
|
||||
|
||||
// 1.6 请求体
|
||||
reqBodyMap := map[string]any{"task_id": taskID}
|
||||
|
||||
// 2) 轮询请求
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
var reqBody io.Reader
|
||||
if method == "POST" {
|
||||
bs, _ := json.Marshal(reqBodyMap)
|
||||
reqBody = bytes.NewReader(bs)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, queryUrl, reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 统一用 headMsg 注入请求头
|
||||
for hk, hv := range ParseHeadMsgHeaders(headMsg) {
|
||||
req.Header.Set(hk, hv)
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[PullTaskResult] 请求失败 taskID=%s err=%v", taskID, err)
|
||||
time.Sleep(time.Duration(interval) * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
raw, _ := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
|
||||
g.Log().Infof(ctx, "[PullTaskResult] taskID=%s statusCode=%d body=%s", taskID, resp.StatusCode, string(raw))
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
time.Sleep(time.Duration(interval) * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
_ = json.Unmarshal(raw, &result)
|
||||
|
||||
statusVal := gjson.New(result).Get(statusPath).Val()
|
||||
statusStr := gconv.String(statusVal)
|
||||
g.Log().Infof(ctx, "[PullTaskResult] 状态 taskID=%s status=%v", taskID, statusVal)
|
||||
|
||||
if matchStatus(statusStr, statusValues["succeeded"]) {
|
||||
g.Log().Infof(ctx, "[PullTaskResult] 任务成功 taskID=%s", taskID)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
if matchStatus(statusStr, statusValues["failed"]) {
|
||||
g.Log().Errorf(ctx, "[PullTaskResult] 任务失败 taskID=%s", taskID)
|
||||
return result, fmt.Errorf("任务失败")
|
||||
}
|
||||
|
||||
time.Sleep(time.Duration(interval) * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func matchStatus(actual string, expected any) bool {
|
||||
expectedStr := gconv.String(expected)
|
||||
if actual == expectedStr {
|
||||
return true
|
||||
}
|
||||
switch v := expected.(type) {
|
||||
case []any:
|
||||
for _, item := range v {
|
||||
if actual == gconv.String(item) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// replaceURLParams 替换 URL 中的 {key}
|
||||
func replaceURLParams(url string, params map[string]any) string {
|
||||
re := regexp.MustCompile(`\{([^}]+)}`)
|
||||
return re.ReplaceAllStringFunc(url, func(s string) string {
|
||||
key := strings.Trim(s, "{}")
|
||||
if val, ok := params[key]; ok {
|
||||
return gconv.String(val)
|
||||
}
|
||||
return s
|
||||
})
|
||||
}
|
||||
|
||||
// InjectCallbackURL 将回调地址注入到请求体中
|
||||
func InjectCallbackURL(ctx context.Context, payload map[string]any, callbackURL string) map[string]any {
|
||||
if callbackURL == "" {
|
||||
return payload
|
||||
}
|
||||
payload[callbackURL] = utils.GetCallbackURL(ctx, "/task/modelCallback")
|
||||
return payload
|
||||
}
|
||||
150
common/util/streaming.go
Normal file
150
common/util/streaming.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
)
|
||||
|
||||
// ================================================================
|
||||
|
||||
// ParseStreamResponse 流式响应解析(通用入口)
|
||||
func ParseStreamResponse(rawBytes []byte, streamConfig map[string]any) (map[string]any, error) {
|
||||
enabled, _ := streamConfig["enabled"].(bool)
|
||||
if !enabled {
|
||||
return gjson.New(string(rawBytes)).Map(), nil
|
||||
}
|
||||
|
||||
parser, _ := streamConfig["parser"].(string)
|
||||
if parser == "base64_concat" {
|
||||
return parseBase64Stream(rawBytes)
|
||||
}
|
||||
|
||||
return parseSSEStream(rawBytes, streamConfig)
|
||||
}
|
||||
|
||||
// parseBase64Stream rebuilds binary audio from a newline-delimited JSON stream
// (used for TTS and similar audio models).
//
// Each non-blank line is expected to be a JSON object whose "data" field holds
// a base64 fragment; lines that are not JSON, or have no non-empty string
// "data", are skipped. Whitespace inside fragments is removed.
//
// Decoding is attempted in two ways:
//  1. all fragments concatenated and decoded as one base64 string (padded or
//     unpadded) — for streams that split one encoding across lines;
//  2. if that fails, each fragment decoded on its own and the bytes appended —
//     for streams whose fragments are independently encoded and padded.
//
// The result is {"audio": []byte}. An error is returned when neither approach
// can decode the stream.
func parseBase64Stream(rawBytes []byte) (map[string]any, error) {
	// stripSpace drops whitespace that may be embedded in a fragment.
	stripSpace := func(s string) string {
		return strings.Map(func(r rune) rune {
			if r == ' ' || r == '\n' || r == '\r' || r == '\t' {
				return -1
			}
			return r
		}, s)
	}
	// decode accepts both padded and unpadded standard base64.
	decode := func(s string) ([]byte, error) {
		out, err := base64.StdEncoding.DecodeString(s)
		if err == nil {
			return out, nil
		}
		return base64.RawStdEncoding.DecodeString(s)
	}

	var fragments []string
	for _, line := range strings.Split(string(rawBytes), "\n") {
		line = strings.TrimSpace(line)
		if line == "" {
			continue
		}

		var chunk map[string]any
		if err := json.Unmarshal([]byte(line), &chunk); err != nil {
			continue // not a JSON line; ignore it
		}
		if data, ok := chunk["data"].(string); ok && data != "" {
			fragments = append(fragments, stripSpace(data))
		}
	}

	audioBytes, err := decode(strings.Join(fragments, ""))
	if err != nil {
		// Padding in the middle of the joined string means the fragments were
		// encoded independently; decode them one at a time.
		var joined []byte
		for _, fragment := range fragments {
			part, partErr := decode(fragment)
			if partErr != nil {
				return nil, fmt.Errorf("base64 解码失败: %w", err)
			}
			joined = append(joined, part...)
		}
		audioBytes = joined
	}

	return map[string]any{"audio": audioBytes}, nil
}
|
||||
|
||||
// parseSSEStream SSE 流式解析(图片模型等)
|
||||
func parseSSEStream(rawBytes []byte, streamConfig map[string]any) (map[string]any, error) {
|
||||
events, _ := streamConfig["events"].([]any)
|
||||
if len(events) == 0 {
|
||||
return gjson.New(string(rawBytes)).Map(), nil
|
||||
}
|
||||
|
||||
lines := strings.Split(string(rawBytes), "\n")
|
||||
result := make(map[string]any)
|
||||
var partials []map[string]any
|
||||
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || line == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, "event:") {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, "data:") {
|
||||
line = strings.TrimPrefix(line, "data:")
|
||||
line = strings.TrimSpace(line)
|
||||
}
|
||||
|
||||
var chunk map[string]any
|
||||
if err := json.Unmarshal([]byte(line), &chunk); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
chunkType, _ := chunk["type"].(string)
|
||||
|
||||
for _, evt := range events {
|
||||
e, _ := evt.(map[string]any)
|
||||
match, _ := e["match"].(string)
|
||||
if !strings.Contains(chunkType, match) {
|
||||
continue
|
||||
}
|
||||
|
||||
fields, _ := e["fields"].(map[string]any)
|
||||
aggregateTo, _ := e["aggregate_to"].(string)
|
||||
evtType, _ := e["type"].(string)
|
||||
|
||||
switch evtType {
|
||||
case "partial":
|
||||
item := make(map[string]any)
|
||||
for localKey, chunkKey := range fields {
|
||||
item[localKey] = chunk[chunkKey.(string)]
|
||||
}
|
||||
partials = append(partials, item)
|
||||
|
||||
case "final":
|
||||
for localKey, chunkKey := range fields {
|
||||
val := gjson.New(chunk).Get(chunkKey.(string))
|
||||
if !val.IsNil() {
|
||||
if _, exists := result[aggregateTo]; !exists {
|
||||
result[aggregateTo] = make(map[string]any)
|
||||
}
|
||||
result[aggregateTo].(map[string]any)[localKey] = val.Val()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(partials) > 0 {
|
||||
for _, evt := range events {
|
||||
e, _ := evt.(map[string]any)
|
||||
if e["type"] == "partial" {
|
||||
if orderBy, ok := e["order_by"].(string); ok {
|
||||
sort.Slice(partials, func(i, j int) bool {
|
||||
return fmt.Sprint(partials[i][orderBy]) < fmt.Sprint(partials[j][orderBy])
|
||||
})
|
||||
}
|
||||
result[e["aggregate_to"].(string)] = partials
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mergedBytes, _ := json.Marshal(result)
|
||||
return gjson.New(mergedBytes).Map(), nil
|
||||
}
|
||||
39
config.yml
39
config.yml
@@ -26,20 +26,45 @@ database:
|
||||
updatedAt: "updated_at" # (可选)自动更新时间字段名称
|
||||
deletedAt: "deleted_at" # (可选)软删除时间字段名称
|
||||
timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性,为true时CreatedAt/UpdatedAt/DeletedAt都将失效
|
||||
model_gateway:
|
||||
- type: "pgsql"
|
||||
host: "116.204.74.41"
|
||||
port: "15432"
|
||||
user: "postgres"
|
||||
pass: "Bjang09@686^*^"
|
||||
name: "model-gateway"
|
||||
prefix: ""
|
||||
role: "master"
|
||||
debug: true
|
||||
dryRun: false
|
||||
charset: "utf8"
|
||||
timezone: "Asia/Shanghai"
|
||||
maxIdle: 5
|
||||
maxOpen: 20
|
||||
maxLifetime: "30s"
|
||||
maxIdleConnTime: "30s"
|
||||
createdAt: "created_at"
|
||||
updatedAt: "updated_at"
|
||||
deletedAt: "deleted_at"
|
||||
timeMaintainDisabled: false
|
||||
|
||||
redis:
|
||||
default:
|
||||
address: 116.204.74.41:6379
|
||||
address: 192.168.3.30:6379
|
||||
db: 0
|
||||
|
||||
consul:
|
||||
address: 116.204.74.41:8500
|
||||
address: 192.168.3.30:8500
|
||||
|
||||
jaeger:
|
||||
addr: 116.204.74.41:4318
|
||||
addr: 192.168.3.30:4318
|
||||
|
||||
# 本地调试用:可选自动执行 worker/cleaner(默认关闭)
|
||||
asynch:
|
||||
queryPending:
|
||||
enabled: false
|
||||
intervalSeconds: 10 # 每10秒轮询一次
|
||||
limit: 10 # 每次查10条
|
||||
worker:
|
||||
enabled: false
|
||||
intervalSeconds: 5
|
||||
@@ -48,11 +73,3 @@ asynch:
|
||||
cleaner:
|
||||
enabled: false
|
||||
intervalSeconds: 30
|
||||
|
||||
modelType:
|
||||
types:
|
||||
1: "推理模型"
|
||||
2: "图片模型"
|
||||
3: "音频模型"
|
||||
4: "向量化模型"
|
||||
5: "全模态模型"
|
||||
|
||||
115
consts/public/public.go
Normal file
115
consts/public/public.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package public
|
||||
|
||||
// Call modes for invoking a model.
const (
	// CallModeSync is a synchronous call.
	CallModeSync = 0
	// CallModeAsync is an asynchronous call.
	CallModeAsync = 1
	// CallModeStream is a streaming call.
	CallModeStream = 2
)
|
||||
|
||||
// Build types describing how a request is constructed.
const (
	// BuildTypePrompt builds from a prompt.
	BuildTypePrompt = 1
	// BuildTypeNode builds from a node.
	BuildTypeNode = 2
	// BuildTypeStruct builds from a structure.
	BuildTypeStruct = 3
)
|
||||
|
||||
// Model type codes. Each family owns a block of one hundred: the round number
// is the family itself and the following values are its sub-types.

// Inference models.
const (
	ModelTypeInference = 100 // inference model
)

// Image models.
const (
	ModelTypeImage               = 200 // image model
	ImageSubTypeTextToImage      = 201 // image: text to image
	ImageSubTypeImageToImage     = 202 // image: image to image
	ImageSubTypeImageEdit        = 203 // image: image editing
	ImageSubTypeImageVariation   = 204 // image: image variation
	ImageSubTypeImageTextToImage = 205 // image: image plus text to image
)

// Audio models.
const (
	ModelTypeAudio             = 300 // audio model
	AudioSubTypeTextToSpeech   = 301 // audio: text to speech
	AudioSubTypeSpeechToText   = 302 // audio: speech to text
	AudioSubTypeSpeechToSpeech = 303 // audio: speech to speech
)

// Vectorisation models.
const (
	ModelTypeVector        = 400 // vectorisation model
	VectorSubTypeEmbedding = 401 // vector: text embedding
	VectorSubTypeRerank    = 402 // vector: reranking
)

// Omni-modal models.
const (
	ModelTypeOmni             = 500 // omni-modal model
	OmniSubTypeTextImageAudio = 501 // omni: text, image and audio
	OmniSubTypeVision         = 502 // omni: visual understanding
)

// Video models.
const (
	ModelTypeVideo               = 600 // video model
	VideoSubTypeTextToVideo      = 601 // video: text to video
	VideoSubTypeImageToVideo     = 602 // video: image to video
	VideoSubTypeImageTextToVideo = 603 // video: image plus text to video
	VideoSubTypeVideoToVideo     = 604 // video: video to video
)
|
||||
|
||||
// ModelTypeName 模型类型名称映射
|
||||
var ModelTypeName = map[int]string{
|
||||
ModelTypeInference: "推理模型",
|
||||
|
||||
ModelTypeImage: "图片模型",
|
||||
ImageSubTypeTextToImage: "图片模型-文生图",
|
||||
ImageSubTypeImageToImage: "图片模型-图生图",
|
||||
ImageSubTypeImageEdit: "图片模型-图片编辑",
|
||||
ImageSubTypeImageVariation: "图片模型-图片变体",
|
||||
ImageSubTypeImageTextToImage: "图片模型-图文生图",
|
||||
|
||||
ModelTypeAudio: "音频模型",
|
||||
AudioSubTypeTextToSpeech: "音频模型-文生音",
|
||||
AudioSubTypeSpeechToText: "音频模型-音生文",
|
||||
AudioSubTypeSpeechToSpeech: "音频模型-音生音",
|
||||
|
||||
ModelTypeVector: "向量化模型",
|
||||
VectorSubTypeEmbedding: "向量化模型-文本嵌入",
|
||||
VectorSubTypeRerank: "向量化模型-重排序",
|
||||
|
||||
ModelTypeOmni: "全模态模型",
|
||||
OmniSubTypeTextImageAudio: "全模态模型-文图音",
|
||||
OmniSubTypeVision: "全模态模型-视觉理解",
|
||||
|
||||
ModelTypeVideo: "视频模型",
|
||||
VideoSubTypeTextToVideo: "视频模型-文生视频",
|
||||
VideoSubTypeImageToVideo: "视频模型-图生视频",
|
||||
VideoSubTypeImageTextToVideo: "视频模型-图文生视频",
|
||||
VideoSubTypeVideoToVideo: "视频模型-视频生视频",
|
||||
}
|
||||
|
||||
// Operator (model provider) display names.

// Providers labelled in Chinese.
const (
	OperatorAliyun     = "阿里云百炼"
	OperatorVolcengine = "火山引擎"
	OperatorTencent    = "腾讯云"
	OperatorHuawei     = "华为云"
	OperatorBaidu      = "百度智能云"
	OperatorZhipu      = "智谱AI"
	OperatorBaichuan   = "百川智能"
	OperatorXunfei     = "科大讯飞"
	OperatorOthers     = "其他"
)

// Providers labelled in English.
const (
	OperatorOpenAI   = "OpenAI"
	OperatorAzure    = "Azure OpenAI"
	OperatorAWS      = "AWS Bedrock"
	OperatorGoogle   = "Google Cloud"
	OperatorDeepSeek = "DeepSeek"
	OperatorMoonshot = "Moonshot"
	OperatorMinimax  = "MiniMax"
)
|
||||
|
||||
// OperatorList 运营商列表(供前端下拉框使用)
|
||||
var OperatorList = []string{
|
||||
OperatorAliyun,
|
||||
OperatorVolcengine,
|
||||
OperatorTencent,
|
||||
OperatorHuawei,
|
||||
OperatorBaidu,
|
||||
OperatorOpenAI,
|
||||
OperatorAzure,
|
||||
OperatorAWS,
|
||||
OperatorGoogle,
|
||||
OperatorDeepSeek,
|
||||
OperatorMoonshot,
|
||||
OperatorZhipu,
|
||||
OperatorBaichuan,
|
||||
OperatorMinimax,
|
||||
OperatorXunfei,
|
||||
OperatorOthers,
|
||||
}
|
||||
@@ -1,5 +1,9 @@
|
||||
package public
|
||||
|
||||
const (
|
||||
DbNameModelGateway = "model_gateway" //数据库名称
|
||||
)
|
||||
|
||||
const (
|
||||
TableNameModel = "asynch_models" // 模型表
|
||||
TableNameTask = "asynch_task" // 任务表
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
package controller
|
||||
@@ -2,12 +2,9 @@ package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"model-gateway/model/dto"
|
||||
"model-gateway/model/entity"
|
||||
"model-gateway/service"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
modelService "model-gateway/service/model"
|
||||
"model-gateway/service/queue"
|
||||
)
|
||||
|
||||
type model struct{}
|
||||
@@ -17,71 +14,53 @@ var Model = new(model)
|
||||
|
||||
// CreateModel 添加配置
|
||||
func (c *model) CreateModel(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
|
||||
return service.Model.Create(ctx, req)
|
||||
return modelService.Model.Create(ctx, req)
|
||||
}
|
||||
|
||||
// UpdateModel 更改配置
|
||||
func (c *model) UpdateModel(ctx context.Context, req *dto.UpdateModelReq) (res *beans.ResponseEmpty, err error) {
|
||||
err = service.Model.Update(ctx, req)
|
||||
func (c *model) UpdateModel(ctx context.Context, req *dto.UpdateModelReq) (res *dto.UpdateModelRes, err error) {
|
||||
err = modelService.Model.Update(ctx, req)
|
||||
return
|
||||
}
|
||||
|
||||
// DeleteModel 删除配置
|
||||
func (c *model) DeleteModel(ctx context.Context, req *dto.DeleteModelReq) (res *beans.ResponseEmpty, err error) {
|
||||
err = service.Model.Delete(ctx, req.ID)
|
||||
func (c *model) DeleteModel(ctx context.Context, req *dto.DeleteModelReq) (res *dto.DeleteModelRes, err error) {
|
||||
err = modelService.Model.Delete(ctx, req)
|
||||
return
|
||||
}
|
||||
|
||||
// GetModel 获取配置详情(按 modelName)
|
||||
// GetModel 获取配置详情
|
||||
func (c *model) GetModel(ctx context.Context, req *dto.GetModelReq) (res *dto.GetModelRes, err error) {
|
||||
model, err := service.Model.Get(ctx, req.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if model == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return &dto.GetModelRes{Model: model}, nil
|
||||
return modelService.Model.Get(ctx, req)
|
||||
}
|
||||
|
||||
// ListModel 配置列表
|
||||
func (c *model) ListModel(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) {
|
||||
list, total, err := service.Model.List(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.ListModelRes{
|
||||
List: list,
|
||||
Total: total,
|
||||
}, nil
|
||||
return modelService.Model.List(ctx, req)
|
||||
}
|
||||
|
||||
// AutoTune 动态调参(由上层定时任务每小时触发一次)
|
||||
func (c *model) AutoTune(ctx context.Context, req *dto.AutoTuneReq) (res *dto.AutoTuneRes, err error) {
|
||||
windowSeconds := 3600
|
||||
if req != nil && req.WindowSeconds > 0 {
|
||||
windowSeconds = req.WindowSeconds
|
||||
}
|
||||
list, err := service.AutoTune(ctx, windowSeconds)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.AutoTuneRes{List: list}, nil
|
||||
return queue.AutoTune(ctx, req)
|
||||
}
|
||||
|
||||
func (c *model) ListType(ctx context.Context, req *dto.ListTypeReq) (res dto.TypeItem, err error) {
|
||||
modelType := service.GetModelTypesFromConfig(ctx)
|
||||
res.Type = modelType
|
||||
return res, nil
|
||||
// ListType 模型类型列表
|
||||
func (c *model) ListType(ctx context.Context, req *dto.ListTypeReq) (res *dto.TypeItem, err error) {
|
||||
return modelService.GetModelTypesFromConfig()
|
||||
}
|
||||
|
||||
// ListOperator 运营商列表
|
||||
func (c *model) ListOperator(ctx context.Context, req *dto.ListOperatorReq) (res *dto.ListOperatorRes, err error) {
|
||||
return modelService.GetOperatorList()
|
||||
}
|
||||
|
||||
// UpdateChatModel 更新是否为聊天模型
|
||||
func (c *model) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) (res *beans.ResponseEmpty, err error) {
|
||||
err = service.Model.UpdateChatModel(ctx, req)
|
||||
func (c *model) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) (res *dto.UpdateChatModelRes, err error) {
|
||||
err = modelService.Model.UpdateChatModel(ctx, req)
|
||||
return
|
||||
}
|
||||
|
||||
// GetIsChatModel 获取是否为聊天模型
|
||||
func (c *model) GetIsChatModel(ctx context.Context, req *dto.GetIsChatModelReq) (res *entity.AsynchModel, err error) {
|
||||
return service.Model.GetIsChatModel(ctx)
|
||||
// GetIsChatModel 获取当前会话模型
|
||||
func (c *model) GetIsChatModel(ctx context.Context, req *dto.GetIsChatModelReq) (res *dto.GetIsChatModelRes, err error) {
|
||||
return modelService.Model.GetIsChatModel(ctx)
|
||||
}
|
||||
|
||||
@@ -2,9 +2,9 @@ package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
statService "model-gateway/service/stat"
|
||||
|
||||
"model-gateway/model/dto"
|
||||
"model-gateway/service"
|
||||
)
|
||||
|
||||
type stat struct{}
|
||||
@@ -14,5 +14,5 @@ var Stat = new(stat)
|
||||
|
||||
// ListModelStat 统计列表
|
||||
func (c *stat) ListModelStat(ctx context.Context, req *dto.ListModelStatReq) (res *dto.ListModelStatRes, err error) {
|
||||
return service.Stat.List(ctx, req)
|
||||
return statService.Stat.List(ctx, req)
|
||||
}
|
||||
|
||||
@@ -2,9 +2,10 @@ package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"model-gateway/service/job"
|
||||
taskService "model-gateway/service/task"
|
||||
|
||||
"model-gateway/model/dto"
|
||||
"model-gateway/service"
|
||||
)
|
||||
|
||||
type task struct{}
|
||||
@@ -14,44 +15,35 @@ var Task = new(task)
|
||||
|
||||
// CreateTask 根据 modelName 创建异步任务,返回 taskId
|
||||
func (c *task) CreateTask(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
|
||||
return service.Task.Create(ctx, req)
|
||||
return taskService.Task.Create(ctx, req)
|
||||
}
|
||||
|
||||
// ModelTaskCallback 接收模型异步任务的回调通知
|
||||
func (c *task) ModelTaskCallback(ctx context.Context, req *dto.ModelTaskCallbackReq) (res *dto.ModelTaskCallbackRes, err error) {
|
||||
return taskService.Task.ModelTaskCallback(ctx, req)
|
||||
}
|
||||
|
||||
// QueryPendingTasks 批量轮询进行中的异步任务
|
||||
func (c *task) QueryPendingTasks(ctx context.Context, req *dto.QueryPendingTasksReq) (res *dto.QueryPendingTasksRes, err error) {
|
||||
return taskService.Task.QueryPendingTasks(ctx, req)
|
||||
}
|
||||
|
||||
// GetTaskResult 获取任务结果(只返回 oss 地址 + state)
|
||||
func (c *task) GetTaskResult(ctx context.Context, req *dto.GetTaskResultReq) (res *dto.GetTaskResultRes, err error) {
|
||||
return service.Task.GetResult(ctx, req.TaskID)
|
||||
return taskService.Task.GetResult(ctx, req.TaskID)
|
||||
}
|
||||
|
||||
// GetTaskBatch 批量查询任务(成功任务标记为已下载)
|
||||
func (c *task) GetTaskBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
|
||||
return service.Task.GetBatch(ctx, req)
|
||||
return taskService.Task.GetBatch(ctx, req)
|
||||
}
|
||||
|
||||
// ListTask 任务列表分页查询
|
||||
func (c *task) ListTask(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
|
||||
return service.Task.List(ctx, req)
|
||||
}
|
||||
|
||||
// RunWork 手动触发一次 worker(由上层定时任务调用)
|
||||
func (c *task) RunWork(ctx context.Context, req *dto.RunWorkReq) (res *dto.RunWorkRes, err error) {
|
||||
batchSize, goroutines := 10, 1
|
||||
if req != nil {
|
||||
if req.BatchSize > 0 {
|
||||
batchSize = req.BatchSize
|
||||
}
|
||||
if req.Goroutines > 0 {
|
||||
goroutines = req.Goroutines
|
||||
}
|
||||
}
|
||||
n, err := service.AsyncWorker.RunOnce(ctx, batchSize, goroutines)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.RunWorkRes{Claimed: n}, nil
|
||||
return taskService.Task.List(ctx, req)
|
||||
}
|
||||
|
||||
// CleanWork 手动触发一次 cleaner(由上层定时任务调用)
|
||||
func (c *task) CleanWork(ctx context.Context, req *dto.CleanWorkReq) (res *dto.CleanWorkRes, err error) {
|
||||
service.Cleaner.RunOnce(ctx)
|
||||
return &dto.CleanWorkRes{Ok: true}, nil
|
||||
return job.Cleaner.RunOnce(ctx)
|
||||
}
|
||||
|
||||
250
dao/model_dao.go
250
dao/model_dao.go
@@ -2,14 +2,12 @@ package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"model-gateway/consts/public"
|
||||
"model-gateway/model/dto"
|
||||
"model-gateway/model/entity"
|
||||
"strconv"
|
||||
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"gitea.com/red-future/common/utils"
|
||||
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
@@ -18,159 +16,148 @@ var Model = &modelDao{}
|
||||
|
||||
type modelDao struct{}
|
||||
|
||||
func (d *modelDao) Insert(ctx context.Context, req *dto.CreateModelReq) (id int64, err error) {
|
||||
asyncModel := new(entity.AsynchModel)
|
||||
err = gconv.Struct(req, &asyncModel)
|
||||
// Insert 插入
|
||||
func (d *modelDao) Insert(ctx context.Context, req *entity.AsynchModel) (id int64, err error) {
|
||||
m := new(entity.AsynchModel)
|
||||
err = gconv.Struct(req, &m)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).Data(asyncModel).Insert()
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||
Insert(m)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return
|
||||
}
|
||||
return r.LastInsertId()
|
||||
}
|
||||
|
||||
func (d *modelDao) Update(ctx context.Context, m *dto.UpdateModelReq) (rows int64, err error) {
|
||||
// 触发 gfdb 的 updateHook 自动填充 updater,需要显式带 updater 字段
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
||||
// Update 更新
|
||||
func (d *modelDao) Update(ctx context.Context, req *entity.AsynchModel) (rows int64, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||
OmitEmpty().
|
||||
Where(entity.AsynchModelCol.Id, m.ID).
|
||||
Data(m).
|
||||
Data(&req).
|
||||
Where(entity.AsynchModelCol.Id, req.Id).
|
||||
Update()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return
|
||||
}
|
||||
return r.RowsAffected()
|
||||
}
|
||||
|
||||
func (d *modelDao) DeleteByID(ctx context.Context, id string) (rows int64, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
||||
Where(entity.AsynchModelCol.Id, id).
|
||||
// Delete 删除
|
||||
func (d *modelDao) Delete(ctx context.Context, req *entity.AsynchModel) (rows int64, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||
OmitEmpty().
|
||||
Where(entity.AsynchModelCol.Id, req.Id).
|
||||
Delete()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return
|
||||
}
|
||||
return r.RowsAffected()
|
||||
}
|
||||
|
||||
func (d *modelDao) GetByModelName(ctx context.Context, modelName string) (m *entity.AsynchModel, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
||||
Where(entity.AsynchModelCol.ModelName, modelName).
|
||||
One()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
err = r.Struct(&m)
|
||||
return
|
||||
}
|
||||
|
||||
func (d *modelDao) Get(ctx context.Context, id int64) (m *entity.AsynchModel, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
||||
NoTenantId(ctx).
|
||||
Where(entity.AsynchModelCol.Id, id).
|
||||
One()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
err = r.Struct(&m)
|
||||
return
|
||||
}
|
||||
|
||||
func (d *modelDao) Count(ctx context.Context, req *dto.GetModelReq) (count int, err error) {
|
||||
count, err = gfdb.DB(ctx).Model(ctx, public.TableNameModel).OmitEmpty().
|
||||
// Get 获取模型
|
||||
func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||
OmitEmpty().
|
||||
Where(entity.AsynchModelCol.Id, req.Id).
|
||||
Where(entity.AsynchModelCol.Creator, req.Creator).
|
||||
Where(entity.AsynchModelCol.Id, req.ID).Count()
|
||||
Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel).
|
||||
Where(entity.AsynchModelCol.ModelName, req.ModelName).
|
||||
Fields(fields).One()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = r.Struct(&m)
|
||||
return
|
||||
}
|
||||
|
||||
func (d *modelDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike string, modelType int, isPrivate int) (list []*entity.AsynchModel, total int64, err error) {
|
||||
model := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
||||
OrderDesc(entity.AsynchModelCol.CreatedAt)
|
||||
if modelNameLike != "" {
|
||||
model = model.WhereLike(entity.AsynchModelCol.ModelName, "%"+modelNameLike+"%")
|
||||
}
|
||||
if modelType != 0 {
|
||||
model = model.Where(entity.AsynchModelCol.ModelType, modelType)
|
||||
}
|
||||
if isPrivate != 0 {
|
||||
model = model.Where(entity.AsynchModelCol.IsPrivate, isPrivate)
|
||||
}
|
||||
if pageNum > 0 && pageSize > 0 {
|
||||
model = model.Page(pageNum, pageSize)
|
||||
}
|
||||
r, totalInt, err := model.AllAndCount(false)
|
||||
//// Get 按ID获取(带租户隔离,只查当前租户)
|
||||
//func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
|
||||
// var whereCondition strings.Builder
|
||||
// var queryParams []interface{}
|
||||
// if !g.IsEmpty(req.Id) {
|
||||
// whereCondition.WriteString(fmt.Sprintf(" AND %s = (?) ", entity.AsynchModelCol.Id))
|
||||
// queryParams = append(queryParams, req.Id)
|
||||
// }
|
||||
// if !g.IsEmpty(req.Creator) {
|
||||
// whereCondition.WriteString(fmt.Sprintf(" AND %s = (?) ", entity.AsynchModelCol.Creator))
|
||||
// queryParams = append(queryParams, req.Creator)
|
||||
// }
|
||||
// if !g.IsEmpty(req.IsChatModel) {
|
||||
// whereCondition.WriteString(fmt.Sprintf(" AND %s = (?) ", entity.AsynchModelCol.IsChatModel))
|
||||
// queryParams = append(queryParams, req.IsChatModel)
|
||||
// }
|
||||
// if !g.IsEmpty(req.ModelName) {
|
||||
// whereCondition.WriteString(fmt.Sprintf(" AND %s = (?) ", entity.AsynchModelCol.ModelName))
|
||||
// queryParams = append(queryParams, req.ModelName)
|
||||
// }
|
||||
// // 完整 SQL
|
||||
// sql := `SELECT * FROM "asynch_models" WHERE "deleted_at" IS NULL` + whereCondition.String()
|
||||
// r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, sql, queryParams...)
|
||||
// if err != nil {
|
||||
// return
|
||||
// }
|
||||
// var i []*entity.AsynchModel
|
||||
// if err = r.Structs(&i); err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// for _, item := range i {
|
||||
// m = item
|
||||
// }
|
||||
// return
|
||||
//}
|
||||
|
||||
// GetByAcrossTenant 按ID获取(跨租户,查所有租户)
|
||||
func (d *modelDao) GetByAcrossTenant(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||
NoTenantId(ctx).
|
||||
OmitEmpty().
|
||||
Where(entity.AsynchModelCol.Id, req.Id).
|
||||
Where(entity.AsynchModelCol.Creator, req.Creator).
|
||||
Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel).
|
||||
Where(entity.AsynchModelCol.ModelName, req.ModelName).
|
||||
Fields(fields).One()
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
return
|
||||
}
|
||||
total = gconv.Int64(totalInt)
|
||||
err = r.Structs(&list)
|
||||
err = r.Struct(&m)
|
||||
return
|
||||
}
|
||||
|
||||
// ListByCreatorAndPlatform 普通用户:平台公共(tenant_id=0) + 自己创建的(creator=xxx)
|
||||
func (d *modelDao) ListByCreatorAndPlatform(ctx context.Context, creator string, pageNum, pageSize int, modelNameLike string) (list []*entity.AsynchModel, total int64, err error) {
|
||||
// 构建 Where 条件
|
||||
whereSQL := "deleted_at IS NULL AND (tenant_id = 1 OR creator = ?)" //1 代表超级管理员
|
||||
args := []any{creator}
|
||||
|
||||
if modelNameLike != "" {
|
||||
whereSQL += " AND model_name LIKE ?"
|
||||
args = append(args, "%"+modelNameLike+"%")
|
||||
}
|
||||
|
||||
// 查总数
|
||||
countSQL := fmt.Sprintf("SELECT COUNT(1) FROM %s WHERE %s", public.TableNameModel, whereSQL)
|
||||
countResult, err := gfdb.DB(ctx).GetAll(ctx, countSQL, args...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if len(countResult) > 0 {
|
||||
total = gconv.Int64(countResult[0]["count"])
|
||||
}
|
||||
|
||||
// 查列表
|
||||
querySQL := fmt.Sprintf("SELECT * FROM %s WHERE %s ORDER BY created_at DESC", public.TableNameModel, whereSQL)
|
||||
if pageNum > 0 && pageSize > 0 {
|
||||
offset := (pageNum - 1) * pageSize
|
||||
querySQL += fmt.Sprintf(" LIMIT %d OFFSET %d", pageSize, offset)
|
||||
}
|
||||
|
||||
r, err := gfdb.DB(ctx).GetAll(ctx, querySQL, args...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
}
|
||||
// GetByCreatorAndPlatform 按创建者、平台获取
|
||||
func (d *modelDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.AsynchModel, total int, err error) {
|
||||
// 基础 SQL
|
||||
sql := `
|
||||
SELECT DISTINCT ON (model_name) *
|
||||
FROM asynch_models
|
||||
WHERE deleted_at IS NULL
|
||||
AND (? = '' OR model_name LIKE ?)
|
||||
AND (? = 0 OR model_type = ?)
|
||||
`
|
||||
args := []any{
|
||||
req.ModelName, "%" + req.ModelName + "%",
|
||||
req.ModelType, req.ModelType,
|
||||
}
|
||||
|
||||
// modelType: 传 6 模糊匹配 6%
|
||||
if req.ModelType > 0 {
|
||||
prefix := strconv.Itoa(req.ModelType)[:1] // 截取第一位
|
||||
sql += ` AND model_type::text LIKE ? `
|
||||
args = append(args, prefix+"%")
|
||||
}
|
||||
|
||||
if !g.IsEmpty(req.IsPrivate) {
|
||||
sql += ` AND is_private = ? `
|
||||
args = append(args, req.IsPrivate)
|
||||
}
|
||||
|
||||
if req.IsOwner != nil && *req.IsOwner == 0 {
|
||||
sql += ` AND creator = ? AND is_owner = ? `
|
||||
args = append(args, req.Creator)
|
||||
args = append(args, req.IsOwner)
|
||||
if req.Enabled != nil && *req.Enabled == 1 {
|
||||
sql += ` AND creator = ? AND is_owner = ? AND enabled=1 `
|
||||
} else if req.Enabled != nil && *req.Enabled == 0 {
|
||||
sql += ` AND creator = ? AND is_owner = ? AND enabled=0 `
|
||||
} else {
|
||||
sql += ` AND creator = ? AND is_owner = ? `
|
||||
}
|
||||
args = append(args, req.Creator, req.IsOwner)
|
||||
} else if req.IsOwner != nil && *req.IsOwner == 1 {
|
||||
if req.Enabled != nil && *req.Enabled == 1 {
|
||||
sql += ` AND ((creator = ? AND is_owner = ? AND enabled=1) OR (is_owner = 0 AND enabled=1)) `
|
||||
@@ -179,14 +166,12 @@ WHERE deleted_at IS NULL
|
||||
} else {
|
||||
sql += ` AND ((creator = ? AND is_owner = ?) OR (is_owner = 0 AND enabled=1)) `
|
||||
}
|
||||
args = append(args, req.Creator)
|
||||
args = append(args, req.IsOwner)
|
||||
args = append(args, req.Creator, req.IsOwner)
|
||||
}
|
||||
|
||||
// 最后拼接排序
|
||||
sql += ` ORDER BY model_name, is_owner DESC, created_at DESC`
|
||||
|
||||
r, err := gfdb.DB(ctx).GetAll(ctx, sql, args...)
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
@@ -200,33 +185,24 @@ WHERE deleted_at IS NULL
|
||||
return
|
||||
}
|
||||
|
||||
func (d *modelDao) GetByIsChatModel(ctx context.Context) (m *entity.AsynchModel, err error) {
|
||||
userInfo, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
||||
Where(entity.AsynchModelCol.IsChatModel, 1).
|
||||
Where(entity.AsynchModelCol.Creator, userInfo.UserName).
|
||||
One()
|
||||
// GetByModelNameForTenant 后台任务使用:按 tenant_id + model_name 查询,不依赖 gfdb Hook/Trace/用户上下文
|
||||
func (d *modelDao) GetByModelNameForTenant(ctx context.Context, tenantId uint64, modelName string) (m *entity.AsynchModel, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx,
|
||||
"SELECT * FROM "+public.TableNameModel+" WHERE tenant_id=? AND model_name=? AND deleted_at IS NULL LIMIT 1",
|
||||
tenantId, modelName,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
err = r.Struct(&m)
|
||||
return
|
||||
}
|
||||
|
||||
// ListAll 用于分组展示:查询全部模型(不按类型过滤,类型拆分在 service 层处理)
|
||||
func (d *modelDao) ListAll(ctx context.Context) (list []*entity.AsynchModel, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
||||
OrderDesc(entity.AsynchModelCol.CreatedAt).
|
||||
All()
|
||||
if err != nil {
|
||||
var list []*entity.AsynchModel
|
||||
if err := r.Structs(&list); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
if len(list) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"model-gateway/consts/public"
|
||||
"model-gateway/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
)
|
||||
|
||||
// GetByModelNameForTenant 后台任务使用:按 tenant_id + model_name 查询,不依赖 gfdb Hook/Trace/用户上下文
|
||||
func (d *modelDao) GetByModelNameForTenant(ctx context.Context, tenantId uint64, modelName string) (m *entity.AsynchModel, err error) {
|
||||
r, err := gfdb.DB(ctx).GetAll(ctx,
|
||||
"SELECT * FROM "+public.TableNameModel+" WHERE tenant_id=? AND model_name=? AND deleted_at IS NULL LIMIT 1",
|
||||
tenantId, modelName,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
var list []*entity.AsynchModel
|
||||
if err := r.Structs(&list); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return list[0], nil
|
||||
}
|
||||
@@ -6,15 +6,23 @@ import (
|
||||
"model-gateway/consts/public"
|
||||
"model-gateway/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
type opLogDao struct{}
|
||||
|
||||
var OpLog = &opLogDao{}
|
||||
|
||||
func (d *opLogDao) Insert(ctx context.Context, log *entity.LogsModelOp) (id int64, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameOpLog).Data(log).Insert()
|
||||
// Insert 插入
|
||||
func (d *opLogDao) Insert(ctx context.Context, req *entity.LogsModelOp) (id int64, err error) {
|
||||
m := new(entity.LogsModelOp)
|
||||
err = gconv.Struct(req, &m)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameOpLog).
|
||||
Insert(m)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"model-gateway/consts/public"
|
||||
"model-gateway/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
|
||||
"github.com/gogf/gf/v2/os/gtime"
|
||||
)
|
||||
|
||||
@@ -25,7 +25,7 @@ ON CONFLICT (day, tenant_id, creator, model_name)
|
||||
DO UPDATE SET request_count = %s.request_count + 1, updated_at = NOW()`,
|
||||
public.TableNameStat, public.TableNameStat,
|
||||
)
|
||||
_, err := gfdb.DB(ctx).Exec(ctx, sql, gtime.New(day).Format("Y-m-d"), tenantId, creator, modelName)
|
||||
_, err := gfdb.DB(ctx, public.DbNameModelGateway).Exec(ctx, sql, gtime.New(day).Format("Y-m-d"), tenantId, creator, modelName)
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
216
dao/task_dao.go
216
dao/task_dao.go
@@ -2,13 +2,10 @@ package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"model-gateway/consts/public"
|
||||
"model-gateway/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/os/gtime"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
@@ -18,40 +15,60 @@ var Task = &taskDao{}
|
||||
|
||||
type taskDao struct{}
|
||||
|
||||
func (d *taskDao) Insert(ctx context.Context, t *entity.AsynchTask) (id int64, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).Data(t).Insert()
|
||||
// Insert 插入
|
||||
func (d *taskDao) Insert(ctx context.Context, req *entity.AsynchTask) (id int64, err error) {
|
||||
m := new(entity.AsynchTask)
|
||||
err = gconv.Struct(req, &m)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return
|
||||
}
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||||
Insert(m)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return r.LastInsertId()
|
||||
}
|
||||
|
||||
func (d *taskDao) GetByTaskID(ctx context.Context, taskID string) (t *entity.AsynchTask, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||||
Where(entity.AsynchTaskCol.TaskID, taskID).
|
||||
One()
|
||||
// Update 更新
|
||||
func (d *taskDao) Update(ctx context.Context, req *entity.AsynchTask) (rows int64, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
|
||||
OmitEmpty().
|
||||
Data(&req).
|
||||
Where(entity.AsynchTaskCol.Id, req.Id).
|
||||
Update()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
return nil, nil
|
||||
return r.RowsAffected()
|
||||
}
|
||||
|
||||
// Get 获取
|
||||
func (d *taskDao) Get(ctx context.Context, req *entity.AsynchTask, fields ...string) (m *entity.AsynchTask, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
|
||||
OmitEmpty().
|
||||
Where(entity.AsynchTaskCol.TaskID, req.TaskID).
|
||||
Fields(fields).One()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = r.Struct(&t)
|
||||
err = r.Struct(&m)
|
||||
return
|
||||
}
|
||||
|
||||
// ListByTaskIDs 批量查询任务(会受 gfdb 的租户 Hook 影响,只返回当前租户数据)
|
||||
func (d *taskDao) ListByTaskIDs(ctx context.Context, taskIDs []string) (list []*entity.AsynchTask, err error) {
|
||||
func (d *taskDao) ListByTaskIDs(ctx context.Context, taskIDs []string) (m []*entity.AsynchTask, err error) {
|
||||
if len(taskIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
|
||||
OmitEmpty().
|
||||
WhereIn(entity.AsynchTaskCol.TaskID, taskIDs).
|
||||
All()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = r.Structs(&list)
|
||||
err = r.Structs(&m)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -62,7 +79,7 @@ func (d *taskDao) MarkDownloadedByID(ctx context.Context, id int64, expireAt *gt
|
||||
entity.AsynchTaskCol.ExpireAt: expireAt,
|
||||
entity.AsynchTaskCol.Updater: "",
|
||||
}
|
||||
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||||
_, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
|
||||
Where(entity.AsynchTaskCol.Id, id).
|
||||
Where(entity.AsynchTaskCol.State, 2).
|
||||
Data(data).
|
||||
@@ -70,76 +87,9 @@ func (d *taskDao) MarkDownloadedByID(ctx context.Context, id int64, expireAt *gt
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *taskDao) UpdateRunning(ctx context.Context, id int64) error {
|
||||
now := gtime.Now()
|
||||
data := gdb.Map{
|
||||
entity.AsynchTaskCol.State: 1,
|
||||
entity.AsynchTaskCol.StartedAt: now,
|
||||
entity.AsynchTaskCol.Updater: "",
|
||||
}
|
||||
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||||
Where(entity.AsynchTaskCol.Id, id).
|
||||
Data(data).
|
||||
Update()
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *taskDao) UpdateSuccess(ctx context.Context, id int64, ossFile, fileType string, fileSize int64, expireAt *gtime.Time) error {
|
||||
now := gtime.Now()
|
||||
data := gdb.Map{
|
||||
entity.AsynchTaskCol.State: 2,
|
||||
entity.AsynchTaskCol.OssFile: ossFile,
|
||||
entity.AsynchTaskCol.FileType: fileType,
|
||||
entity.AsynchTaskCol.FileSize: fileSize,
|
||||
entity.AsynchTaskCol.ErrorMsg: "",
|
||||
entity.AsynchTaskCol.FinishedAt: now,
|
||||
entity.AsynchTaskCol.ExpireAt: expireAt,
|
||||
entity.AsynchTaskCol.Updater: "",
|
||||
}
|
||||
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||||
Where(entity.AsynchTaskCol.Id, id).
|
||||
Data(data).
|
||||
Update()
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *taskDao) UpdateFailed(ctx context.Context, id int64, errorMsg string) error {
|
||||
now := gtime.Now()
|
||||
data := gdb.Map{
|
||||
entity.AsynchTaskCol.State: 3,
|
||||
entity.AsynchTaskCol.ErrorMsg: errorMsg,
|
||||
entity.AsynchTaskCol.FinishedAt: now,
|
||||
entity.AsynchTaskCol.Updater: "",
|
||||
}
|
||||
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||||
Where(entity.AsynchTaskCol.Id, id).
|
||||
Data(data).
|
||||
Update()
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *taskDao) SoftDeleteByTaskID(ctx context.Context, taskID string) (rows int64, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||||
Where(entity.AsynchTaskCol.TaskID, taskID).
|
||||
Delete()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return r.RowsAffected()
|
||||
}
|
||||
|
||||
// CountActiveByModel 统计某模型排队中/执行中的任务数,用于 queue_limit 限制(近似值)
|
||||
func (d *taskDao) CountActiveByModel(ctx context.Context, modelName string) (int64, error) {
|
||||
n, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||||
Where(entity.AsynchTaskCol.ModelName, modelName).
|
||||
WhereIn(entity.AsynchTaskCol.State, []int{0, 1}).
|
||||
Count()
|
||||
return int64(n), err
|
||||
}
|
||||
|
||||
// List 任务分页查询(受 gfdb 租户 Hook 影响)
|
||||
func (d *taskDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike, taskIDLike string, state *int) (list []*entity.AsynchTask, total int64, err error) {
|
||||
m := gfdb.DB(ctx).Model(ctx, public.TableNameTask).Where("deleted_at IS NULL")
|
||||
m := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).Where("deleted_at IS NULL")
|
||||
if modelNameLike != "" {
|
||||
m = m.WhereLike(entity.AsynchTaskCol.ModelName, "%"+modelNameLike+"%")
|
||||
}
|
||||
@@ -162,89 +112,13 @@ func (d *taskDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike
|
||||
return
|
||||
}
|
||||
|
||||
// ClaimPending 抢占 pending 任务(state=0),并在同一事务中更新为 running(state=1)
|
||||
// 使用 PostgreSQL: FOR UPDATE SKIP LOCKED 避免多 worker 重复消费
|
||||
func (d *taskDao) ClaimPending(ctx context.Context, batchSize int) (tasks []*entity.AsynchTask, err error) {
|
||||
if batchSize <= 0 {
|
||||
batchSize = 1
|
||||
}
|
||||
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
||||
sql := fmt.Sprintf(
|
||||
`SELECT id, tenant_id, model_name, task_id, input_ref, request_payload
|
||||
FROM %s
|
||||
WHERE deleted_at IS NULL AND state = 0
|
||||
ORDER BY created_at ASC
|
||||
LIMIT %d
|
||||
FOR UPDATE SKIP LOCKED`,
|
||||
public.TableNameTask,
|
||||
batchSize,
|
||||
)
|
||||
r, err := tx.GetAll(sql)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
tasks = nil
|
||||
return nil
|
||||
}
|
||||
if err := r.Structs(&tasks); err != nil {
|
||||
return err
|
||||
}
|
||||
// 更新为 running
|
||||
now := time.Now()
|
||||
for _, t := range tasks {
|
||||
// tx.Model 不走 gfdb Hook,这里手动更新必要字段
|
||||
_, err = tx.Exec(
|
||||
fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask),
|
||||
now, now, t.Id,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// ListExpiredSuccess 获取已成功且过期的任务
|
||||
func (d *taskDao) ListExpiredSuccess(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
|
||||
if limit <= 0 {
|
||||
limit = 100
|
||||
}
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||||
Where(entity.AsynchTaskCol.State, 2).
|
||||
Where(entity.AsynchTaskCol.ExpireAt+" IS NOT NULL").
|
||||
Where(entity.AsynchTaskCol.ExpireAt+" < ?", gtime.Now()).
|
||||
// GetPendingAsyncTasks 获取进行中的异步任务
|
||||
func (d *taskDao) GetPendingAsyncTasks(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
|
||||
var tasks []*entity.AsynchTask
|
||||
err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
|
||||
Where("state", 1).
|
||||
Where("deleted_at IS NULL").
|
||||
Limit(limit).
|
||||
All()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
}
|
||||
|
||||
// ListTimeoutTasks 获取超时的排队/执行中任务
|
||||
func (d *taskDao) ListTimeoutTasks(ctx context.Context, timeout time.Duration, limit int) (list []*entity.AsynchTask, err error) {
|
||||
if limit <= 0 {
|
||||
limit = 100
|
||||
}
|
||||
deadline := gtime.New(time.Now().Add(-timeout))
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||||
WhereIn(entity.AsynchTaskCol.State, []int{0, 1}).
|
||||
Where(entity.AsynchTaskCol.UpdatedAt+" < ?", deadline).
|
||||
Limit(limit).
|
||||
All()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
}
|
||||
|
||||
// DebugPing 用于启动时检测数据库连通性(可选)
|
||||
func (d *taskDao) DebugPing(ctx context.Context) error {
|
||||
_, err := gfdb.DB(ctx).GetAll(ctx, "SELECT 1")
|
||||
return err
|
||||
Scan(&tasks)
|
||||
return tasks, err
|
||||
}
|
||||
|
||||
@@ -8,33 +8,58 @@ import (
|
||||
"model-gateway/consts/public"
|
||||
"model-gateway/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/os/gtime"
|
||||
)
|
||||
|
||||
// ClaimPendingGlobal 后台任务使用:全局抢占 pending 任务(不加 tenant 过滤)
|
||||
func (d *taskDao) ClaimPendingGlobal(ctx context.Context, batchSize int) (tasks []*entity.AsynchTask, err error) {
|
||||
// ======================== 查询辅助 ========================
|
||||
|
||||
// taskColumns 查询用的公共字段
|
||||
const taskColumns = `id, tenant_id, creator, model_name, task_id, biz_name, callback_url, model_key, retry_count, input_ref, request_payload, phase, tmp_file`
|
||||
|
||||
// ======================== 事务抢占 ========================
|
||||
|
||||
// claimTasks 事务内抢占任务并更新 state=1
|
||||
func claimTasks(ctx context.Context, where string, args ...any) ([]*entity.AsynchTask, error) {
|
||||
var tasks []*entity.AsynchTask
|
||||
err := gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
||||
sql := fmt.Sprintf(`SELECT %s FROM %s WHERE deleted_at IS NULL AND state = 0 %s LIMIT 1 FOR UPDATE SKIP LOCKED`, taskColumns, public.TableNameTask, where)
|
||||
r, err := tx.GetOne(sql, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
return nil
|
||||
}
|
||||
var task entity.AsynchTask
|
||||
if err := r.Struct(&task); err != nil {
|
||||
return err
|
||||
}
|
||||
now := time.Now()
|
||||
_, err = tx.Exec(fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask), now, now, task.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tasks = []*entity.AsynchTask{&task}
|
||||
return nil
|
||||
})
|
||||
return tasks, err
|
||||
}
|
||||
|
||||
// ClaimPendingGlobal 批量抢占 pending 任务
|
||||
func (d *taskDao) ClaimPendingGlobal(ctx context.Context, batchSize int) ([]*entity.AsynchTask, error) {
|
||||
if batchSize <= 0 {
|
||||
batchSize = 1
|
||||
}
|
||||
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
||||
sql := fmt.Sprintf(
|
||||
`SELECT id, tenant_id, creator, model_name, task_id, biz_name, callback_url, model_key, input_ref, request_payload, phase, tmp_file
|
||||
FROM %s
|
||||
WHERE deleted_at IS NULL AND state = 0
|
||||
ORDER BY enqueue_at ASC
|
||||
LIMIT %d
|
||||
FOR UPDATE SKIP LOCKED`,
|
||||
public.TableNameTask,
|
||||
batchSize,
|
||||
)
|
||||
var tasks []*entity.AsynchTask
|
||||
err := gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
||||
sql := fmt.Sprintf(`SELECT %s FROM %s WHERE deleted_at IS NULL AND state = 0 ORDER BY enqueue_at ASC LIMIT %d FOR UPDATE SKIP LOCKED`, taskColumns, public.TableNameTask, batchSize)
|
||||
r, err := tx.GetAll(sql)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
tasks = nil
|
||||
return nil
|
||||
}
|
||||
if err := r.Structs(&tasks); err != nil {
|
||||
@@ -42,245 +67,148 @@ func (d *taskDao) ClaimPendingGlobal(ctx context.Context, batchSize int) (tasks
|
||||
}
|
||||
now := time.Now()
|
||||
for _, t := range tasks {
|
||||
_, err = tx.Exec(
|
||||
fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask),
|
||||
now, now, t.Id,
|
||||
)
|
||||
_, err = tx.Exec(fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask), now, now, t.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return
|
||||
return tasks, err
|
||||
}
|
||||
|
||||
// ClaimPendingByTaskIDGlobal 按 task_id 定向抢占单个 pending 任务(不加 tenant 过滤)
|
||||
// 用于 createTask 创建成功后立即异步尝试执行当前任务,避免只依赖后续 runWork 扫描队列。
|
||||
func (d *taskDao) ClaimPendingByTaskIDGlobal(ctx context.Context, taskID string) (task *entity.AsynchTask, err error) {
|
||||
// ClaimPendingByTaskIDGlobal 按 task_id 抢占
|
||||
func (d *taskDao) ClaimPendingByTaskIDGlobal(ctx context.Context, taskID string) (*entity.AsynchTask, error) {
|
||||
if taskID == "" {
|
||||
return nil, nil
|
||||
}
|
||||
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
||||
sql := fmt.Sprintf(
|
||||
`SELECT id, tenant_id, creator, model_name, task_id, biz_name, callback_url, model_key, input_ref, request_payload, phase, tmp_file
|
||||
FROM %s
|
||||
WHERE deleted_at IS NULL AND state = 0 AND task_id = ?
|
||||
LIMIT 1
|
||||
FOR UPDATE SKIP LOCKED`,
|
||||
public.TableNameTask,
|
||||
)
|
||||
r, err := tx.GetOne(sql, taskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
task = nil
|
||||
return nil
|
||||
}
|
||||
if err := r.Struct(&task); err != nil {
|
||||
return err
|
||||
}
|
||||
now := time.Now()
|
||||
_, err = tx.Exec(
|
||||
fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask),
|
||||
now, now, task.Id,
|
||||
)
|
||||
return err
|
||||
tasks, err := claimTasks(ctx, "AND task_id = ?", taskID)
|
||||
if err != nil || len(tasks) == 0 {
|
||||
return nil, err
|
||||
}
|
||||
return tasks[0], nil
|
||||
}
|
||||
|
||||
// ======================== 更新辅助 ========================
|
||||
|
||||
func execSQL(ctx context.Context, sql string, args ...any) error {
|
||||
_, err := gfdb.DB(ctx).Exec(ctx, sql, args...)
|
||||
return err
|
||||
}
|
||||
|
||||
// updateTask 通用更新
|
||||
func updateTask(ctx context.Context, id int64, data entity.AsynchTask) error {
|
||||
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).OmitEmpty().
|
||||
Where(entity.AsynchTaskCol.Id, id).Data(data).Update()
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateSuccessGlobal 更新任务成功
|
||||
func (d *taskDao) UpdateSuccessGlobal(ctx context.Context, t *entity.AsynchTask) error {
|
||||
return updateTask(ctx, t.Id, entity.AsynchTask{
|
||||
State: 2,
|
||||
OssFile: t.OssFile,
|
||||
FileType: t.FileType,
|
||||
TextResult: t.TextResult,
|
||||
FileSize: t.FileSize,
|
||||
ErrorMsg: "",
|
||||
FinishedAt: gtime.Now(),
|
||||
Phase: 0,
|
||||
TmpFile: "",
|
||||
ExpendTokens: t.ExpendTokens,
|
||||
DurationSeconds: t.DurationSeconds,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (d *taskDao) UpdateSuccessGlobal(ctx context.Context, id int64, ossFile, fileType, textResult string, fileSize int64, expireAt *gtime.Time, expendTokens int) error {
|
||||
now := gtime.Now()
|
||||
_, err := gfdb.DB(ctx).Exec(ctx,
|
||||
fmt.Sprintf(`UPDATE %s
|
||||
SET state=2,
|
||||
oss_file=?,
|
||||
file_type=?,
|
||||
text_result=?,
|
||||
expend_tokens=?,
|
||||
file_size=?,
|
||||
error_msg='',
|
||||
finished_at=?,
|
||||
duration_seconds=EXTRACT(EPOCH FROM (? - created_at))::BIGINT,
|
||||
expire_at=NULL,
|
||||
phase=0,
|
||||
tmp_file='',
|
||||
updated_at=?
|
||||
WHERE id=?`, public.TableNameTask),
|
||||
ossFile, fileType, textResult, expendTokens, fileSize, now, now, now, id,
|
||||
)
|
||||
return err
|
||||
// UpdateFailedGlobal 模型调用失败
|
||||
func (d *taskDao) UpdateFailedGlobal(ctx context.Context, t *entity.AsynchTask) error {
|
||||
return updateTask(ctx, t.Id, entity.AsynchTask{
|
||||
State: 3,
|
||||
ErrorMsg: t.ErrorMsg,
|
||||
FinishedAt: gtime.Now(),
|
||||
Phase: 0,
|
||||
TmpFile: "",
|
||||
TextResult: t.TextResult,
|
||||
DurationSeconds: t.DurationSeconds,
|
||||
})
|
||||
}
|
||||
|
||||
func (d *taskDao) UpdateFailedGlobal(ctx context.Context, id int64, errorMsg string) error {
|
||||
now := gtime.Now()
|
||||
_, err := gfdb.DB(ctx).Exec(ctx,
|
||||
fmt.Sprintf(`UPDATE %s
|
||||
SET state=3,
|
||||
error_msg=?,
|
||||
finished_at=?,
|
||||
duration_seconds=EXTRACT(EPOCH FROM (? - created_at))::BIGINT,
|
||||
phase=0,
|
||||
tmp_file='',
|
||||
updated_at=?
|
||||
WHERE id=?`, public.TableNameTask),
|
||||
errorMsg, now, now, now, id,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateFailedKeepTmpGlobal OSS 上传失败:保留 phase/tmp_file,下一轮仅重试 OSS 上传
|
||||
// UpdateFailedKeepTmpGlobal OSS 上传失败
|
||||
func (d *taskDao) UpdateFailedKeepTmpGlobal(ctx context.Context, id int64, errorMsg string) error {
|
||||
now := gtime.Now()
|
||||
_, err := gfdb.DB(ctx).Exec(ctx,
|
||||
fmt.Sprintf(`UPDATE %s SET state=3, error_msg=?, finished_at=?, phase=1, updated_at=? WHERE id=?`, public.TableNameTask),
|
||||
errorMsg, now, now, id,
|
||||
)
|
||||
return err
|
||||
return execSQL(ctx, fmt.Sprintf(`UPDATE %s SET state=3, error_msg=?, finished_at=?, phase=1, updated_at=? WHERE id=?`, public.TableNameTask), errorMsg, gtime.Now(), gtime.Now(), id)
|
||||
}
|
||||
|
||||
// UpdateTmpAfterModelGlobal 模型调用成功后,写入临时文件路径并标记 phase=1
|
||||
// UpdateTmpAfterModelGlobal 写临时文件
|
||||
func (d *taskDao) UpdateTmpAfterModelGlobal(ctx context.Context, id int64, tmpFile string) error {
|
||||
_, err := gfdb.DB(ctx).Exec(ctx,
|
||||
fmt.Sprintf(`UPDATE %s SET phase=1, tmp_file=?, updated_at=NOW() WHERE id=?`, public.TableNameTask),
|
||||
tmpFile, id,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *taskDao) SoftDeleteByTaskIDGlobal(ctx context.Context, taskID string) error {
|
||||
_, err := gfdb.DB(ctx).Exec(ctx,
|
||||
fmt.Sprintf(`UPDATE %s SET deleted_at=NOW(), updated_at=NOW() WHERE task_id=? AND deleted_at IS NULL`, public.TableNameTask),
|
||||
taskID,
|
||||
)
|
||||
return err
|
||||
return execSQL(ctx, fmt.Sprintf(`UPDATE %s SET phase=1, tmp_file=?, updated_at=NOW() WHERE id=?`, public.TableNameTask), tmpFile, id)
|
||||
}
|
||||
|
||||
// RollbackToPendingGlobal 回滚
|
||||
func (d *taskDao) RollbackToPendingGlobal(ctx context.Context, id int64) error {
|
||||
_, err := gfdb.DB(ctx).Exec(ctx,
|
||||
fmt.Sprintf(`UPDATE %s SET state=0, enqueue_at=NOW(), updated_at=NOW() WHERE id=? AND state=1`, public.TableNameTask),
|
||||
id,
|
||||
)
|
||||
return err
|
||||
return execSQL(ctx, fmt.Sprintf(`UPDATE %s SET state=0, enqueue_at=NOW(), updated_at=NOW() WHERE id=? AND state=1`, public.TableNameTask), id)
|
||||
}
|
||||
|
||||
// ListExpiredDownloadedGlobal 获取已下载(state=4)且过期的任务,用于清理
|
||||
func (d *taskDao) ListExpiredDownloadedGlobal(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
|
||||
if limit <= 0 {
|
||||
limit = 200
|
||||
}
|
||||
r, err := gfdb.DB(ctx).GetAll(ctx,
|
||||
fmt.Sprintf(`SELECT * FROM %s WHERE deleted_at IS NULL AND state=4 AND expire_at IS NOT NULL AND expire_at < ? LIMIT ?`, public.TableNameTask),
|
||||
gtime.Now(), limit,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
// IncRetryCountGlobal 重试计数+1
|
||||
func (d *taskDao) IncRetryCountGlobal(ctx context.Context, id int64) error {
|
||||
return execSQL(ctx, fmt.Sprintf(`UPDATE %s SET retry_count=retry_count+1, updated_at=NOW() WHERE id=?`, public.TableNameTask), id)
|
||||
}
|
||||
|
||||
// ListFailedRetryableGlobal 获取失败(state=3)且仍可重试的任务
|
||||
// retry_count 不含首次执行;retry_times 表示失败后最多再重试 N 次
|
||||
func (d *taskDao) ListFailedRetryableGlobal(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
|
||||
if limit <= 0 {
|
||||
limit = 200
|
||||
}
|
||||
r, err := gfdb.DB(ctx).GetAll(ctx,
|
||||
fmt.Sprintf(`
|
||||
SELECT t.*,
|
||||
m.retry_queue_max_seconds AS retry_queue_max_seconds
|
||||
FROM %s t
|
||||
JOIN %s m
|
||||
ON t.tenant_id = m.tenant_id
|
||||
AND t.model_name = m.model_name
|
||||
WHERE t.deleted_at IS NULL
|
||||
AND t.state = 3
|
||||
AND t.retry_count < m.retry_times
|
||||
ORDER BY t.updated_at ASC
|
||||
LIMIT ?`, public.TableNameTask, public.TableNameModel),
|
||||
limit,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
}
|
||||
|
||||
// RequeueForRetryGlobal 将任务重新入队(state=0),并将 retry_count +1
|
||||
// enqueueAt 用于控制重试任务在队列中的位置:
|
||||
// - enqueueAt 越早,越靠前(ClaimPendingGlobal 按 enqueue_at ASC 抢占)
|
||||
// RequeueForRetryGlobal 重新入队
|
||||
func (d *taskDao) RequeueForRetryGlobal(ctx context.Context, id int64, enqueueAt time.Time) error {
|
||||
_, err := gfdb.DB(ctx).Exec(ctx,
|
||||
fmt.Sprintf(`UPDATE %s SET state=0, retry_count=retry_count+1, enqueue_at=?, updated_at=NOW() WHERE id=? AND state=3 AND deleted_at IS NULL`, public.TableNameTask),
|
||||
enqueueAt, id,
|
||||
)
|
||||
return err
|
||||
return execSQL(ctx, fmt.Sprintf(`UPDATE %s SET state=0, retry_count=retry_count+1, enqueue_at=?, updated_at=NOW() WHERE id=? AND state=3 AND deleted_at IS NULL`, public.TableNameTask), enqueueAt, id)
|
||||
}
|
||||
|
||||
// ListFailedExhaustedGlobal 获取失败(state=3)且超过重试次数的任务,用于硬删除
|
||||
func (d *taskDao) ListFailedExhaustedGlobal(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
|
||||
if limit <= 0 {
|
||||
limit = 200
|
||||
}
|
||||
r, err := gfdb.DB(ctx).GetAll(ctx,
|
||||
fmt.Sprintf(`
|
||||
SELECT t.*
|
||||
FROM %s t
|
||||
JOIN %s m
|
||||
ON t.tenant_id = m.tenant_id
|
||||
AND t.model_name = m.model_name
|
||||
WHERE t.deleted_at IS NULL
|
||||
AND t.state = 3
|
||||
AND t.retry_count >= m.retry_times
|
||||
ORDER BY t.updated_at ASC
|
||||
LIMIT ?`, public.TableNameTask, public.TableNameModel),
|
||||
limit,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
// ======================== 列表查询 ========================
|
||||
|
||||
// ListExpiredDownloadedGlobal
|
||||
func (d *taskDao) ListExpiredDownloadedGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
|
||||
return queryTasks(ctx, fmt.Sprintf(`SELECT * FROM %s WHERE deleted_at IS NULL AND state=4 AND expire_at IS NOT NULL AND expire_at < ? LIMIT ?`, public.TableNameTask), gtime.Now(), clampLimit(limit, 200))
|
||||
}
|
||||
|
||||
// HardDeleteByIDGlobal 硬删除任务记录
|
||||
// ListFailedRetryableGlobal
|
||||
func (d *taskDao) ListFailedRetryableGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
|
||||
return queryTasks(ctx, fmt.Sprintf(`SELECT t.*, m.retry_queue_max_seconds FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state=3 AND t.retry_count < m.retry_times ORDER BY t.updated_at ASC LIMIT ?`, public.TableNameTask, public.TableNameModel), clampLimit(limit, 200))
|
||||
}
|
||||
|
||||
// ListFailedExhaustedGlobal
|
||||
func (d *taskDao) ListFailedExhaustedGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
|
||||
return queryTasks(ctx, fmt.Sprintf(`SELECT t.* FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state=3 AND t.retry_count >= m.retry_times ORDER BY t.updated_at ASC LIMIT ?`, public.TableNameTask, public.TableNameModel), clampLimit(limit, 200))
|
||||
}
|
||||
|
||||
// ListTimeoutTasksGlobal
|
||||
func (d *taskDao) ListTimeoutTasksGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
|
||||
return queryTasks(ctx, fmt.Sprintf(`SELECT t.* FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state IN (0,1) AND m.expected_seconds > 0 AND t.created_at < (NOW() - (m.expected_seconds || ' seconds')::interval) LIMIT ?`, public.TableNameTask, public.TableNameModel), clampLimit(limit, 200))
|
||||
}
|
||||
|
||||
// HardDeleteByIDGlobal
|
||||
func (d *taskDao) HardDeleteByIDGlobal(ctx context.Context, id int64) error {
|
||||
_, err := gfdb.DB(ctx).Exec(ctx,
|
||||
fmt.Sprintf(`DELETE FROM %s WHERE id=?`, public.TableNameTask),
|
||||
id,
|
||||
)
|
||||
return err
|
||||
return execSQL(ctx, fmt.Sprintf(`DELETE FROM %s WHERE id=?`, public.TableNameTask), id)
|
||||
}
|
||||
|
||||
// ListTimeoutTasksGlobal 根据模型配置 expected_seconds 判定超时任务:
|
||||
// - state in (0,1)
|
||||
// - 模型 expected_seconds > 0
|
||||
// - now - created_at >= expected_seconds
|
||||
func (d *taskDao) ListTimeoutTasksGlobal(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
|
||||
if limit <= 0 {
|
||||
limit = 200
|
||||
}
|
||||
r, err := gfdb.DB(ctx).GetAll(ctx,
|
||||
fmt.Sprintf(`
|
||||
SELECT t.*
|
||||
FROM %s t
|
||||
JOIN %s m
|
||||
ON t.tenant_id = m.tenant_id
|
||||
AND t.model_name = m.model_name
|
||||
WHERE t.deleted_at IS NULL
|
||||
AND t.state IN (0,1)
|
||||
AND m.expected_seconds > 0
|
||||
AND t.created_at < (NOW() - (m.expected_seconds || ' seconds')::interval)
|
||||
LIMIT ?`, public.TableNameTask, public.TableNameModel),
|
||||
limit,
|
||||
)
|
||||
// ======================== 内部辅助 ========================
|
||||
|
||||
func queryTasks(ctx context.Context, sql string, args ...any) ([]*entity.AsynchTask, error) {
|
||||
r, err := gfdb.DB(ctx).GetAll(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var list []*entity.AsynchTask
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
return list, err
|
||||
}
|
||||
|
||||
// clampLimit returns limit when it is positive and defaultVal otherwise.
// Despite the name, no upper bound is applied.
func clampLimit(limit, defaultVal int) int {
	if limit > 0 {
		return limit
	}
	return defaultVal
}
|
||||
|
||||
// UpdateColumns 更新指定字段(结构体版)
|
||||
func (d *taskDao) UpdateColumns(ctx context.Context, id int64, data entity.AsynchTask) error {
|
||||
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).OmitEmpty().
|
||||
Where(entity.AsynchTaskCol.Id, id).
|
||||
Data(data).
|
||||
Update()
|
||||
return err
|
||||
}
|
||||
|
||||
34
go.mod
34
go.mod
@@ -1,22 +1,14 @@
|
||||
module model-gateway
|
||||
|
||||
go 1.26.0
|
||||
go 1.26.1
|
||||
|
||||
require (
|
||||
gitea.com/red-future/common v0.0.19
|
||||
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.0
|
||||
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.0
|
||||
github.com/gogf/gf/v2 v2.10.0
|
||||
gitea.redpowerfuture.com/red-future/common v0.0.23
|
||||
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.2
|
||||
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.2
|
||||
github.com/gogf/gf/v2 v2.10.2
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/tidwall/gjson v1.14.2
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/r3labs/diff/v2 v2.15.1 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/vmihailenco/msgpack v4.0.4+incompatible // indirect
|
||||
google.golang.org/appengine v1.6.7 // indirect
|
||||
github.com/tidwall/gjson v1.19.0
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -57,7 +49,7 @@ require (
|
||||
github.com/hashicorp/go-rootcerts v1.0.2 // indirect
|
||||
github.com/hashicorp/golang-lru v1.0.2 // indirect
|
||||
github.com/hashicorp/serf v0.10.1 // indirect
|
||||
github.com/klauspost/compress v1.18.2 // indirect
|
||||
github.com/klauspost/compress v1.18.0 // indirect
|
||||
github.com/lib/pq v1.10.9 // indirect
|
||||
github.com/magiconair/properties v1.8.10 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
@@ -69,11 +61,14 @@ require (
|
||||
github.com/olekukonko/ll v0.0.9 // indirect
|
||||
github.com/olekukonko/tablewriter v1.1.0 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/r3labs/diff/v2 v2.15.1 // indirect
|
||||
github.com/redis/go-redis/v9 v9.12.1 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tiger1103/gfast-token v1.0.10 // indirect
|
||||
github.com/vcaesar/cedar v0.30.0 // indirect
|
||||
github.com/vmihailenco/msgpack v4.0.4+incompatible // indirect
|
||||
go.mongodb.org/mongo-driver/v2 v2.4.0 // indirect
|
||||
go.opencensus.io v0.23.0 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||
@@ -85,9 +80,10 @@ require (
|
||||
go.opentelemetry.io/otel/trace v1.38.0 // indirect
|
||||
go.opentelemetry.io/proto/otlp v1.7.1 // indirect
|
||||
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect
|
||||
golang.org/x/net v0.48.0 // indirect
|
||||
golang.org/x/sys v0.39.0 // indirect
|
||||
golang.org/x/text v0.32.0 // indirect
|
||||
golang.org/x/net v0.47.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
google.golang.org/appengine v1.6.7 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect
|
||||
google.golang.org/grpc v1.75.0 // indirect
|
||||
|
||||
52
go.sum
52
go.sum
@@ -1,6 +1,6 @@
|
||||
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
||||
gitea.com/red-future/common v0.0.19 h1:9/WrfCFUCeFUYwuhBYF+JOQi5F5xuOy+gVnf2ZvHZu4=
|
||||
gitea.com/red-future/common v0.0.19/go.mod h1:6/nqIucVzmjOyqDTIq71feYBXXFNBy0rFwzaQ0/Ueoo=
|
||||
gitea.redpowerfuture.com/red-future/common v0.0.23 h1:xieoA00iKOCDm5SO9iXn+cSyMKBAlZwI0fuEVPWrHLg=
|
||||
gitea.redpowerfuture.com/red-future/common v0.0.23/go.mod h1:50U1Xi+Ie56z09S5LQbZvaken0Mxv3OeS9LgR7U/ZRY=
|
||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
|
||||
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
||||
@@ -77,16 +77,16 @@ github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ4
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
|
||||
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.0 h1:39+jbTenm7KBj4hO2C8ANAxVHpX/7OuRDs1VcGC9ylA=
|
||||
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.0/go.mod h1:B0s0fVzn0W220E8UTpSGzrrGKsop5KcB90twBeLCiz0=
|
||||
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.0 h1:N/F9CuDdUZLoM1nVRqrDE/33pDZuhVxpNY4wYdeIaBs=
|
||||
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.0/go.mod h1:x6uoJGfZOtirIRQls8xUlYzC6f7T/eULPUa9er368X0=
|
||||
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.2 h1:u8EpP24GkprogROnJ7htMov9Fc66pTP1eVYrWxiCYOs=
|
||||
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.2/go.mod h1:GmvM3r8GVByVMi4RD2+MCs5+CfxVXPMeT8mVDkAaAXE=
|
||||
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.2 h1:iTQegT+lEg/wDKvj2mi3W1wrdrwFarjokf88EXVVgu4=
|
||||
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.2/go.mod h1:ZRw3GNz5cq4uYrW4TPSVyrYWaoqzujKdWro/AOcGBaE=
|
||||
github.com/gogf/gf/contrib/registry/consul/v2 v2.9.5 h1:eUqwJ/qNH8lJ6yssiqskazgp1ACQuNU6zXlLOZVuXTQ=
|
||||
github.com/gogf/gf/contrib/registry/consul/v2 v2.9.5/go.mod h1:sjQyMry9+0POYZCA6lHXBxO77WoNKkruJpRB4xKqk5k=
|
||||
github.com/gogf/gf/contrib/trace/otlphttp/v2 v2.9.5 h1:tHUEZYB5GTqEYYVDYnlGobf1xISARKDE4KHVlgjwTec=
|
||||
github.com/gogf/gf/contrib/trace/otlphttp/v2 v2.9.5/go.mod h1:cfzTn2HS9RDX8f5pUVkbGxUWcSosouqfNQ1G6cY0V88=
|
||||
github.com/gogf/gf/v2 v2.10.0 h1:rzDROlyqGMe/eM6dCalSR8dZOuMIdLhmxKSH1DGhbFs=
|
||||
github.com/gogf/gf/v2 v2.10.0/go.mod h1:Svl1N+E8G/QshU2DUbh/3J/AJauqCgUnxHurXWR4Qx0=
|
||||
github.com/gogf/gf/v2 v2.10.2 h1:46IO0Uc8e85/FqdftJFskfDejJLBL0JBnGS5qOftUu8=
|
||||
github.com/gogf/gf/v2 v2.10.2/go.mod h1:Svl1N+E8G/QshU2DUbh/3J/AJauqCgUnxHurXWR4Qx0=
|
||||
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
|
||||
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
@@ -185,8 +185,8 @@ github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/u
|
||||
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
|
||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk=
|
||||
github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
@@ -288,14 +288,12 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
|
||||
github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/tidwall/gjson v1.14.2 h1:6BBkirS0rAHjumnjHF6qgy5d2YAJ1TLIaFE2lzfOLqo=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.19.0 h1:xwxm7n691Uf3u5OFjzngavjGTh55KX5q/9w9xHW88JU=
|
||||
github.com/tidwall/gjson v1.19.0/go.mod h1:V37/opeE/JbLUOfH0QTXiNez2l0RUjYUhpT4szFQAfc=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/tiger1103/gfast-token v1.0.10 h1:fNiBE/Dq5iTHvTGlCx3DmXa2o4hr0NtumFpffZ39k6s=
|
||||
github.com/tiger1103/gfast-token v1.0.10/go.mod h1:a/21mxmj7zFeNvjhZSC0XpEAFHfb1aT2k6DXnufFU1s=
|
||||
github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM=
|
||||
@@ -344,8 +342,8 @@ golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvx
|
||||
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
|
||||
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
|
||||
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
|
||||
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
|
||||
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
@@ -361,8 +359,8 @@ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwY
|
||||
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210410081132-afb366fc7cd1/go.mod h1:9tjilg8BloeKEkVJvy7fQ90B1CfIiPueXVOjqfkSzI8=
|
||||
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
|
||||
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -371,8 +369,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
@@ -397,15 +395,15 @@ golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
|
||||
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
|
||||
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
|
||||
@@ -415,8 +413,8 @@ golang.org/x/tools v0.0.0-20190907020128-2ca718005c18/go.mod h1:b+2E5dAYhXwXZwtn
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
|
||||
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
|
||||
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
||||
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
@@ -426,6 +424,8 @@ gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E
|
||||
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
|
||||
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||
google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
|
||||
google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c=
|
||||
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
|
||||
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
|
||||
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
|
||||
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
|
||||
|
||||
60
main.go
60
main.go
@@ -2,17 +2,19 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"model-gateway/model/dto"
|
||||
"model-gateway/service/job"
|
||||
"model-gateway/service/task"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"model-gateway/controller"
|
||||
"model-gateway/service"
|
||||
|
||||
"gitea.com/red-future/common/http"
|
||||
"gitea.com/red-future/common/jaeger"
|
||||
_ "gitea.com/red-future/common/swagger"
|
||||
"gitea.redpowerfuture.com/red-future/common/http"
|
||||
"gitea.redpowerfuture.com/red-future/common/jaeger"
|
||||
_ "gitea.redpowerfuture.com/red-future/common/swagger"
|
||||
_ "github.com/gogf/gf/contrib/drivers/pgsql/v2"
|
||||
_ "github.com/gogf/gf/contrib/nosql/redis/v2"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
@@ -33,42 +35,18 @@ func main() {
|
||||
// 本地调试:可选自动触发 worker/cleaner(由配置文件控制)
|
||||
startAutoRunner(ctx)
|
||||
|
||||
// 监听退出信号,确保 Ctrl+C 能完整退出(停止 worker/cleaner 并关闭 http server)
|
||||
// 监听退出信号,确保 Ctrl+C 能完整退出(停止 worker/cleaner 并关闭 gateway server)
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, os.Interrupt, syscall.SIGTERM)
|
||||
<-quit
|
||||
|
||||
g.Log().Infof(ctx, "[main] 收到退出信号,开始优雅退出...")
|
||||
cancel()
|
||||
// 关闭 http server(RouteRegister 内部是 go Httpserver.Run() 启动的)
|
||||
// 关闭 gateway server(RouteRegister 内部是 go Httpserver.Run() 启动的)
|
||||
_ = http.Httpserver.Shutdown()
|
||||
}
|
||||
|
||||
func startAutoRunner(ctx context.Context) {
|
||||
// worker
|
||||
if g.Cfg().MustGet(ctx, "asynch.worker.enabled").Bool() {
|
||||
interval := g.Cfg().MustGet(ctx, "asynch.worker.intervalSeconds").Int()
|
||||
if interval <= 0 {
|
||||
interval = 5
|
||||
}
|
||||
batchSize := g.Cfg().MustGet(ctx, "asynch.worker.batchSize").Int()
|
||||
goroutines := g.Cfg().MustGet(ctx, "asynch.worker.goroutines").Int()
|
||||
ticker := time.NewTicker(time.Duration(interval) * time.Second)
|
||||
go func() {
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if _, err := service.AsyncWorker.RunOnce(ctx, batchSize, goroutines); err != nil {
|
||||
g.Log().Warningf(ctx, "[auto-worker] run once failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// cleaner
|
||||
if g.Cfg().MustGet(ctx, "asynch.cleaner.enabled").Bool() {
|
||||
interval := g.Cfg().MustGet(ctx, "asynch.cleaner.intervalSeconds").Int()
|
||||
@@ -83,7 +61,27 @@ func startAutoRunner(ctx context.Context) {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
service.Cleaner.RunOnce(ctx)
|
||||
_, _ = job.Cleaner.RunOnce(ctx)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// queryPending
|
||||
if g.Cfg().MustGet(ctx, "asynch.queryPending.enabled").Bool() {
|
||||
interval := g.Cfg().MustGet(ctx, "asynch.queryPending.intervalSeconds", 10).Int()
|
||||
limit := g.Cfg().MustGet(ctx, "asynch.queryPending.limit", 10).Int()
|
||||
ticker := time.NewTicker(time.Duration(interval) * time.Second)
|
||||
go func() {
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if _, err := task.Task.QueryPendingTasks(ctx, &dto.QueryPendingTasksReq{Limit: limit}); err != nil {
|
||||
g.Log().Warningf(ctx, "[auto-queryPending] run once failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -1,36 +1,44 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"gitea.com/red-future/common/beans"
|
||||
"model-gateway/model/entity"
|
||||
|
||||
"gitea.redpowerfuture.com/red-future/common/beans"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
// CreateModelReq 添加模型配置
|
||||
type CreateModelReq struct {
|
||||
g.Meta `path:"/createModel" method:"post" tags:"模型管理" summary:"创建模型配置" dc:"添加新的模型配置"`
|
||||
ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"模型名称(唯一标识)"`
|
||||
ModelType int `p:"modelType" json:"modelType" v:"required#modelType不能为空" dc:"模型类型:1-文本生成 2-图像生成 3-语音 4-视频 5-多模态"`
|
||||
BaseURL string `p:"baseUrl" json:"baseUrl" v:"required#baseUrl不能为空" dc:"模型服务基础地址(如 http(s)://host:port)"`
|
||||
HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式:GET/POST(默认POST)"`
|
||||
HeadMsg string `p:"headMsg" json:"headMsg" dc:"请求头绑定(支持多个,逗号分隔),示例:Authorization:Bearer xxx,Content-Type:application/json"`
|
||||
IsPrivate *int `p:"isPrivate" json:"isPrivate" v:"in:0,1#私有化参数只能为0或1" dc:"是否私有化:0-私有(默认) 1-公共"`
|
||||
Enabled *int `p:"enabled" json:"enabled" v:"in:0,1#启用参数只能为0或1" dc:"是否启用:0-禁用,1-启用(默认1)"`
|
||||
IsChatModel *int `p:"isChatModel" json:"isChatModel" v:"in:0,1#对话模型参数只能为0或1" dc:"是否为对话模型:0-否,1-是(默认0)"`
|
||||
IsOwner *int `p:"isOwner" json:"isOwner" v:"in:0,1#是否为所有者参数只能为0或1" dc:"是否为所有者:0-否,1-是(默认0)"`
|
||||
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥,用于模型认证"`
|
||||
Form any `p:"form" json:"form" dc:"动态表单配置(JSON),用于前端渲染配置项"`
|
||||
RequestMapping any `p:"requestMapping" json:"requestMapping" dc:"请求映射"`
|
||||
ResponseMapping any `p:"responseMapping" json:"responseMapping" dc:"返回映射"`
|
||||
ResponseBody any `p:"responseBody" json:"responseBody" dc:"返回主体"`
|
||||
TokenMapping string `p:"tokenMapping" json:"tokenMapping" dc:"token映射"`
|
||||
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数(默认10)"`
|
||||
QueueLimit int `p:"queueLimit" json:"queueLimit" dc:"排队队列上限(默认1000)"`
|
||||
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒,默认600)"`
|
||||
ExpectedSeconds int `p:"expectedSeconds" json:"expectedSeconds" dc:"模型预计执行时间(秒,默认600)"`
|
||||
RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数(默认3)"`
|
||||
RetryQueueMaxSeconds int `p:"retryQueueMaxSeconds" json:"retryQueueMaxSeconds" dc:"失败重试最大排队时间(秒,默认600)"`
|
||||
AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"任务完成后自动清理时间(秒,默认86400)"`
|
||||
Remark string `p:"remark" json:"remark" dc:"备注说明"`
|
||||
g.Meta `path:"/createModel" method:"post" tags:"模型管理" summary:"创建模型配置" dc:"添加新的模型配置"`
|
||||
ModelName string `p:"modelName" json:"modelName" v:"required#模型名称不能为空" dc:"模型名称(唯一标识)"`
|
||||
ModelType int `p:"modelType" json:"modelType" v:"required#模型类型不能为空" dc:"模型类型"`
|
||||
BaseURL string `p:"baseUrl" json:"baseUrl" v:"required#模型地址不能为空" dc:"模型服务地址"`
|
||||
HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式:GET/POST(默认POST)"`
|
||||
HeadMsg map[string]any `p:"headMsg" json:"headMsg" dc:"请求头JSON结构"`
|
||||
IsPrivate *int `p:"isPrivate" json:"isPrivate" dc:"是否私有化:0-私有 1-公共"`
|
||||
Enabled *int `p:"enabled" json:"enabled" dc:"是否启用:0-停用 1-启用"`
|
||||
IsChatModel *int `p:"isChatModel" json:"isChatModel" dc:"是否为对话模型:0-否 1-是"`
|
||||
CallModel *int `p:"callModel" json:"callModel" dc:"调用模式:0-同步 1-异步 2-流式"`
|
||||
RequiredFields []string `p:"requiredFields" json:"requiredFields" dc:"必填字段"`
|
||||
IsOwner *int `p:"isOwner" json:"isOwner" dc:"是否为所有者:0-否 1-是"`
|
||||
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥"`
|
||||
Form []map[string]any `p:"form" json:"form" dc:"动态表单配置"`
|
||||
RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求映射"`
|
||||
ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回映射"`
|
||||
ResponseBody string `p:"responseBody" json:"responseBody" dc:"返回主体"`
|
||||
ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"`
|
||||
OperatorName string `p:"operatorName" json:"operatorName" dc:"运营商名称"`
|
||||
TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"`
|
||||
ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"`
|
||||
QueryConfig map[string]any `p:"queryConfig" json:"queryConfig" dc:"查询/回调配置"`
|
||||
StreamConfig map[string]any `p:"streamConfig" json:"streamConfig" dc:"流式输出配置"`
|
||||
FirstFrame string `p:"firstFrame" json:"firstFrame" dc:"首帧图片参数"`
|
||||
LastFrame string `p:"lastFrame" json:"lastFrame" dc:"尾帧图片参数"`
|
||||
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数(默认10)"`
|
||||
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒,默认600)"`
|
||||
RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数(默认3)"`
|
||||
AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"任务完成后自动清理时间(秒,默认86400)"`
|
||||
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址"`
|
||||
}
|
||||
|
||||
type CreateModelRes struct {
|
||||
@@ -38,48 +46,64 @@ type CreateModelRes struct {
|
||||
}
|
||||
|
||||
type UpdateModelReq struct {
|
||||
g.Meta `path:"/updateModel" method:"put" tags:"模型管理" summary:"更新模型配置" dc:"更新指定ID的模型配置"`
|
||||
ID int64 `p:"id" json:"id" v:"required#id不能为空" dc:"配置ID"`
|
||||
ModelName string `p:"modelName" json:"modelName" dc:"模型名称(唯一标识)"`
|
||||
ModelType int `p:"modelType" json:"modelType" dc:"模型类型ID列表(逗号分隔)(可选更新)"`
|
||||
BaseURL string `p:"baseUrl" json:"baseUrl" dc:"模型服务基础地址"`
|
||||
HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式:GET/POST(可选更新)"`
|
||||
HeadMsg string `p:"headMsg" json:"headMsg" dc:"请求头绑定(可选更新)"`
|
||||
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥,用于模型认证(可选更新)"`
|
||||
Form any `p:"form" json:"form" dc:"动态表单配置(JSON)(可选更新)"`
|
||||
RequestMapping any `p:"requestMapping" json:"requestMapping" dc:"请求参数映射(可选更新)"`
|
||||
ResponseMapping any `p:"responseMapping" json:"responseMapping" dc:"返回参数映射(可选更新)"`
|
||||
ResponseBody any `p:"responseBody" json:"responseBody" dc:"返回主体(可选更新)"`
|
||||
TokenMapping string `p:"tokenMapping" json:"tokenMapping" dc:"token映射(可选更新)"`
|
||||
Enabled *int `p:"enabled" json:"enabled" dc:"是否启用:0-禁用,1-启用(可选更新)"`
|
||||
IsPrivate *int `p:"isPrivate" json:"isPrivate" v:"in:0,1#私有化参数只能为0或1" dc:"是否私有化:0-私有(默认) 1-公共"`
|
||||
IsChatModel *int `p:"isChatModel" json:"isChatModel" v:"in:0,1#对话模型参数只能为0或1" dc:"是否为对话模型:0-否,1-是(默认0)"`
|
||||
IsOwner *int `p:"isOwner" json:"isOwner" v:"in:0,1#是否为所有者参数只能为0或1" dc:"是否为所有者:0-否,1-是(默认0)"`
|
||||
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数(可选更新)"`
|
||||
QueueLimit int `p:"queueLimit" json:"queueLimit" dc:"排队队列上限(可选更新)"`
|
||||
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒)(可选更新)"`
|
||||
ExpectedSeconds int `p:"expectedSeconds" json:"expectedSeconds" dc:"模型预计执行时间(秒)(可选更新)"`
|
||||
RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数(可选更新)"`
|
||||
RetryQueueMaxSeconds int `p:"retryQueueMaxSeconds" json:"retryQueueMaxSeconds" dc:"失败重试最大排队时间(秒)(可选更新)"`
|
||||
AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"自动清理间隔(秒)(可选更新)"`
|
||||
Remark string `p:"remark" json:"remark" dc:"备注说明(可选更新)"`
|
||||
g.Meta `path:"/updateModel" method:"put" tags:"模型管理" summary:"更新模型配置" dc:"更新指定ID的模型配置"`
|
||||
ID int64 `p:"id" json:"id" v:"required#id不能为空" dc:"配置ID"`
|
||||
ModelName string `p:"modelName" json:"modelName" dc:"模型名称"`
|
||||
ModelType int `p:"modelType" json:"modelType" dc:"模型类型"`
|
||||
BaseURL string `p:"baseUrl" json:"baseUrl" dc:"模型服务地址"`
|
||||
HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式:GET/POST"`
|
||||
HeadMsg map[string]any `p:"headMsg" json:"headMsg" dc:"请求头JSON结构"`
|
||||
IsPrivate *int `p:"isPrivate" json:"isPrivate" dc:"是否私有化:0-私有 1-公共"`
|
||||
Enabled *int `p:"enabled" json:"enabled" dc:"是否启用:0-停用 1-启用"`
|
||||
IsChatModel *int `p:"isChatModel" json:"isChatModel" dc:"是否为对话模型:0-否 1-是"`
|
||||
CallModel *int `p:"callModel" json:"callModel" dc:"调用模式:0-同步 1-异步 2-流式"`
|
||||
RequiredFields []string `p:"requiredFields" json:"requiredFields" dc:"必填字段"`
|
||||
IsOwner *int `p:"isOwner" json:"isOwner" dc:"是否为所有者:0-否 1-是"`
|
||||
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥"`
|
||||
Form []map[string]any `p:"form" json:"form" dc:"动态表单配置"`
|
||||
RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求映射"`
|
||||
ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回映射"`
|
||||
ResponseBody string `p:"responseBody" json:"responseBody" dc:"返回主体"`
|
||||
ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"`
|
||||
OperatorName string `p:"operatorName" json:"operatorName" dc:"运营商名称"`
|
||||
TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"`
|
||||
ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"`
|
||||
QueryConfig map[string]any `p:"queryConfig" json:"queryConfig" dc:"查询/回调配置"`
|
||||
StreamConfig map[string]any `p:"streamConfig" json:"streamConfig" dc:"流式输出配置"`
|
||||
FirstFrame string `p:"firstFrame" json:"firstFrame" dc:"首帧图片参数"`
|
||||
LastFrame string `p:"lastFrame" json:"lastFrame" dc:"尾帧图片参数"`
|
||||
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数"`
|
||||
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒)"`
|
||||
RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数"`
|
||||
AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"任务完成后自动清理时间(秒)"`
|
||||
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址"`
|
||||
}
|
||||
|
||||
type UpdateModelRes struct {
|
||||
ID int64 `json:"id,string" dc:"配置ID"`
|
||||
}
|
||||
|
||||
// DeleteModelReq 删除模型配置
|
||||
type DeleteModelReq struct {
|
||||
g.Meta `path:"/deleteModel" method:"delete" tags:"模型管理" summary:"删除模型配置" dc:"删除指定ID的模型配置"`
|
||||
ID string `p:"id" json:"id,string" v:"required#id不能为空" dc:"配置ID"`
|
||||
ID int64 `p:"id" json:"id,string" v:"required#id不能为空" dc:"配置ID"`
|
||||
}
|
||||
|
||||
type DeleteModelRes struct {
|
||||
ID int64 `json:"id,string" dc:"配置ID"`
|
||||
}
|
||||
|
||||
// GetModelReq 获取模型配置详情
|
||||
type GetModelReq struct {
|
||||
g.Meta `path:"/getModel" method:"get" tags:"模型管理" summary:"获取模型配置" dc:"根据模型ID获取配置详情"`
|
||||
ID int64 `p:"id" json:"id,string" v:"required#id不能为空" dc:"配置ID"`
|
||||
Creator string `p:"creator" json:"creator" dc:"创建人"`
|
||||
g.Meta `path:"/getModel" method:"get" tags:"模型管理" summary:"获取模型配置" dc:"根据模型ID获取配置详情"`
|
||||
ID int64 `p:"id" json:"id,string" dc:"配置ID"`
|
||||
Creator string `p:"creator" json:"creator" dc:"创建人"`
|
||||
IsChatModel *int `p:"isChatModel" json:"isChatModel" dc:"是否为聊天模型"`
|
||||
ModelName string `p:"modelName" json:"modelName" dc:"模型名称(唯一标识)"`
|
||||
}
|
||||
|
||||
type GetModelRes struct {
|
||||
Model any `json:"model" dc:"模型配置详情"`
|
||||
Model *entity.AsynchModel `json:"model" dc:"模型配置详情"`
|
||||
}
|
||||
|
||||
// ListModelReq 配置列表
|
||||
@@ -124,11 +148,43 @@ type TypeItem struct {
|
||||
Type map[int]string `json:"type" dc:"模型类型ID到名称的映射"`
|
||||
}
|
||||
|
||||
type ListOperatorReq struct {
|
||||
g.Meta `path:"/listOperator" method:"get" tags:"模型管理" summary:"获取运营商列表" dc:"获取运营商列表"`
|
||||
}
|
||||
|
||||
type ListOperatorRes struct {
|
||||
List []string `json:"list" dc:"运营商名称到ID的映射"`
|
||||
}
|
||||
|
||||
type UpdateChatModelReq struct {
|
||||
g.Meta `path:"/updateChatModel" method:"post" tags:"模型管理" summary:"更新聊天模型" dc:"更新指定模型的聊天模型"`
|
||||
Id int64 `p:"id" json:"id" v:"required#model不能为空" dc:"模型id"`
|
||||
}
|
||||
type UpdateChatModelRes struct {
|
||||
ID int64 `json:"id,string" dc:"模型ID"`
|
||||
}
|
||||
|
||||
type GetIsChatModelReq struct {
|
||||
g.Meta `path:"/getIsChatModel" method:"get" tags:"模型管理" summary:"获取模型是否为聊天模型" dc:"根据模型ID获取是否为聊天模型"`
|
||||
}
|
||||
|
||||
type GetIsChatModelRes struct {
|
||||
Model any `json:"model" dc:"模型详情"`
|
||||
}
|
||||
|
||||
// NodeFormField 节点表单
|
||||
type NodeFormField struct {
|
||||
Value any `json:"value" dc:"字段值"`
|
||||
Field string `json:"field" dc:"字段标识"`
|
||||
Label string `json:"label" dc:"字段标签"`
|
||||
Type string `json:"type" dc:"字段类型"`
|
||||
Required bool `json:"required" dc:"是否必填"`
|
||||
Default any `json:"default,omitempty" dc:"默认值"`
|
||||
Options []SelectOption `json:"options" dc:"下拉选项列表"`
|
||||
FieldConstraint any `json:"fieldConstraint" dc:"字段约束"`
|
||||
}
|
||||
|
||||
type SelectOption struct {
|
||||
Label string `json:"label" dc:"选项标签"`
|
||||
Value string `json:"value" dc:"选项值"`
|
||||
}
|
||||
|
||||
@@ -5,22 +5,55 @@ import "github.com/gogf/gf/v2/frame/g"
|
||||
// CreateTaskReq 创建异步任务
|
||||
type CreateTaskReq struct {
|
||||
g.Meta `path:"/createTask" method:"post" tags:"任务管理" summary:"创建异步任务" dc:"创建异步任务并返回任务ID;创建成功后会立即异步尝试执行当前任务,执行成功后按回调配置触发钩子"`
|
||||
ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"模型名称"`
|
||||
BizName string `p:"bizName" json:"bizName" dc:"业务名称(调用方模块/系统,用于统计)"`
|
||||
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址(可选,用于后续业务通知)"`
|
||||
InputRef string `p:"inputRef" json:"inputRef" dc:"输入引用(如OSS/文件引用等)"`
|
||||
RequestPayload any `p:"requestPayload" json:"requestPayload" dc:"请求负载(透传给模型服务)"`
|
||||
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
|
||||
ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"模型名称"`
|
||||
BizName string `p:"bizName" json:"bizName" dc:"业务名称(调用方模块/系统,用于统计)"`
|
||||
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址(可选,用于后续业务通知)"`
|
||||
InputRef string `p:"inputRef" json:"inputRef" dc:"输入引用(如OSS/文件引用等)"`
|
||||
RequestPayload map[string]any `p:"requestPayload" json:"requestPayload" dc:"请求负载(透传给模型服务)"`
|
||||
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
|
||||
BuildType int64 `json:"buildType" dc:"构建类型:1-提示词构建 2-节点构建"`
|
||||
}
|
||||
|
||||
type CreateTaskRes struct {
|
||||
TaskID string `json:"taskId" dc:"任务ID"`
|
||||
}
|
||||
|
||||
type ModelTaskCallbackReq struct {
|
||||
g.Meta `path:"/modelCallback" method:"post" tags:"异步任务" summary:"模型任务回调通知"`
|
||||
TaskID string `json:"id" dc:"任务ID"`
|
||||
Status string `json:"status" dc:"queued/running/succeeded/failed/expired"`
|
||||
Content map[string]any `json:"content,omitempty" dc:"任务结果内容"`
|
||||
Usage map[string]any `json:"usage,omitempty" dc:"token用量"`
|
||||
}
|
||||
|
||||
type ModelTaskCallbackRes struct {
|
||||
Success bool `json:"success" dc:"是否接收成功"`
|
||||
}
|
||||
|
||||
// QueryPendingTasksReq 批量轮询请求
|
||||
type QueryPendingTasksReq struct {
|
||||
g.Meta `path:"/queryPending" method:"get" tags:"异步任务" summary:"批量轮询进行中的任务"`
|
||||
Limit int `p:"limit" json:"limit" dc:"查询数量,默认10"`
|
||||
}
|
||||
|
||||
// QueryPendingTasksRes 批量轮询响应
|
||||
type QueryPendingTasksRes struct {
|
||||
Total int `json:"total" dc:"本次查询数量"`
|
||||
Results []QueryTaskItem `json:"results" dc:"查询结果列表"`
|
||||
}
|
||||
|
||||
// QueryTaskItem 单个任务查询结果
|
||||
type QueryTaskItem struct {
|
||||
TaskID string `json:"taskId" dc:"任务ID"`
|
||||
Status string `json:"status" dc:"任务状态"`
|
||||
Content map[string]any `json:"content,omitempty" dc:"结果内容"`
|
||||
Usage map[string]any `json:"usage,omitempty" dc:"token用量"`
|
||||
}
|
||||
|
||||
// GetTaskResultReq 获取结果(只返回 oss 地址)
|
||||
type GetTaskResultReq struct {
|
||||
g.Meta `path:"/getTaskResult" method:"get" tags:"任务管理" summary:"获取任务结果" dc:"根据任务ID获取结果(只返回OSS地址)"`
|
||||
TaskID string `p:"taskId" json:"taskId" v:"required#taskId不能为空" dc:"任务ID"`
|
||||
TaskID string `p:"taskId" json:"taskId" v:"required#taskwId不能为空" dc:"任务ID"`
|
||||
}
|
||||
|
||||
type GetTaskResultRes struct {
|
||||
|
||||
@@ -1,88 +1,103 @@
|
||||
package entity
|
||||
|
||||
import "gitea.com/red-future/common/beans"
|
||||
import "gitea.redpowerfuture.com/red-future/common/beans"
|
||||
|
||||
type asynchModelCol struct {
|
||||
beans.SQLBaseCol
|
||||
ModelName string
|
||||
ModelType string
|
||||
BaseURL string
|
||||
HttpMethod string
|
||||
HeadMsg string
|
||||
FormJSON string
|
||||
RequestMapping string
|
||||
ResponseMapping string
|
||||
ResponseBody string
|
||||
TokenMapping string
|
||||
Prompt string
|
||||
IsPrivate string
|
||||
IsChatModel string
|
||||
ApiKey string
|
||||
Enabled string
|
||||
MaxConcurrency string
|
||||
QueueLimit string
|
||||
TimeoutSeconds string
|
||||
ExpectedSeconds string
|
||||
RetryTimes string
|
||||
RetryQueueMaxSecs string
|
||||
AutoCleanSeconds string
|
||||
Remark string
|
||||
IsOwner string
|
||||
ModelName string
|
||||
ModelType string
|
||||
BaseURL string
|
||||
HttpMethod string
|
||||
HeadMsg string
|
||||
FormJSON string
|
||||
RequestMapping string
|
||||
ResponseMapping string
|
||||
ResponseBody string
|
||||
ResponseTokenField string
|
||||
RequiredFields string
|
||||
IsPrivate string
|
||||
IsChatModel string
|
||||
CallMode string
|
||||
ApiKey string
|
||||
Enabled string
|
||||
MaxConcurrency string
|
||||
TimeoutSeconds string
|
||||
RetryTimes string
|
||||
AutoCleanSeconds string
|
||||
IsOwner string
|
||||
OperatorName string
|
||||
TokenConfig string
|
||||
ExtendMapping string
|
||||
QueryConfig string
|
||||
StreamConfig string
|
||||
FirstFrame string
|
||||
LastFrame string
|
||||
CallbackUrl string
|
||||
}
|
||||
|
||||
var AsynchModelCol = asynchModelCol{
|
||||
SQLBaseCol: beans.DefSQLBaseCol,
|
||||
ModelName: "model_name",
|
||||
ModelType: "model_type",
|
||||
BaseURL: "base_url",
|
||||
HttpMethod: "http_method",
|
||||
HeadMsg: "head_msg",
|
||||
FormJSON: "form_json",
|
||||
RequestMapping: "request_mapping",
|
||||
ResponseMapping: "response_mapping",
|
||||
ResponseBody: "response_body",
|
||||
TokenMapping: "token_mapping",
|
||||
Prompt: "prompt",
|
||||
IsPrivate: "is_private",
|
||||
IsChatModel: "is_chat_model",
|
||||
ApiKey: "api_key",
|
||||
Enabled: "enabled",
|
||||
MaxConcurrency: "max_concurrency",
|
||||
QueueLimit: "queue_limit",
|
||||
TimeoutSeconds: "timeout_seconds",
|
||||
ExpectedSeconds: "expected_seconds",
|
||||
RetryTimes: "retry_times",
|
||||
RetryQueueMaxSecs: "retry_queue_max_seconds",
|
||||
AutoCleanSeconds: "auto_clean_seconds",
|
||||
Remark: "remark",
|
||||
IsOwner: "is_owner",
|
||||
SQLBaseCol: beans.DefSQLBaseCol,
|
||||
ModelName: "model_name",
|
||||
ModelType: "model_type",
|
||||
BaseURL: "base_url",
|
||||
HttpMethod: "http_method",
|
||||
HeadMsg: "head_msg",
|
||||
FormJSON: "form_json",
|
||||
RequestMapping: "request_mapping",
|
||||
ResponseMapping: "response_mapping",
|
||||
ResponseBody: "response_body",
|
||||
ResponseTokenField: "response_token_field",
|
||||
RequiredFields: "required_fields",
|
||||
IsPrivate: "is_private",
|
||||
IsChatModel: "is_chat_model",
|
||||
CallMode: "call_mode",
|
||||
ApiKey: "api_key",
|
||||
Enabled: "enabled",
|
||||
MaxConcurrency: "max_concurrency",
|
||||
TimeoutSeconds: "timeout_seconds",
|
||||
RetryTimes: "retry_times",
|
||||
AutoCleanSeconds: "auto_clean_seconds",
|
||||
IsOwner: "is_owner",
|
||||
OperatorName: "operator_name",
|
||||
TokenConfig: "token_config",
|
||||
ExtendMapping: "extend_mapping",
|
||||
QueryConfig: "query_config",
|
||||
StreamConfig: "stream_config",
|
||||
FirstFrame: "first_frame",
|
||||
LastFrame: "last_frame",
|
||||
CallbackUrl: "callback_url",
|
||||
}
|
||||
|
||||
// AsynchModel 异步模型配置
|
||||
type AsynchModel struct {
|
||||
beans.SQLBaseDO `orm:",inline"`
|
||||
ModelName string `orm:"model_name" json:"modelName"`
|
||||
ModelType int `orm:"model_type" json:"modelType"`
|
||||
BaseURL string `orm:"base_url" json:"baseUrl"`
|
||||
HttpMethod string `orm:"http_method" json:"httpMethod"`
|
||||
HeadMsg string `orm:"head_msg" json:"headMsg"`
|
||||
Form any `orm:"form_json" json:"form"`
|
||||
RequestMapping any `orm:"request_mapping" json:"requestMapping"`
|
||||
ResponseMapping any `orm:"response_mapping" json:"responseMapping"`
|
||||
ResponseBody any `orm:"response_body" json:"responseBody"`
|
||||
TokenMapping string `orm:"token_mapping" json:"tokenMapping"`
|
||||
Prompt string `orm:"prompt" json:"prompt"`
|
||||
IsPrivate *int `orm:"is_private" json:"isPrivate"`
|
||||
IsChatModel *int `orm:"is_chat_model" json:"isChatModel"`
|
||||
ApiKey string `orm:"api_key" json:"apiKey"`
|
||||
Enabled *int `orm:"enabled" json:"enabled"`
|
||||
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
|
||||
QueueLimit int `orm:"queue_limit" json:"queueLimit"`
|
||||
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
|
||||
ExpectedSeconds int `orm:"expected_seconds" json:"expectedSeconds"`
|
||||
RetryTimes int `orm:"retry_times" json:"retryTimes"`
|
||||
RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"retryQueueMaxSeconds"`
|
||||
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
|
||||
Remark string `orm:"remark" json:"remark"`
|
||||
IsOwner *int `json:"isOwner" orm:"is_owner"` // 1=当前用户创建的,0=超级管理员的
|
||||
beans.SQLBaseDO `orm:",inline"`
|
||||
ModelName string `orm:"model_name" json:"modelName"`
|
||||
ModelType int `orm:"model_type" json:"modelType"`
|
||||
BaseURL string `orm:"base_url" json:"baseUrl"`
|
||||
HttpMethod string `orm:"http_method" json:"httpMethod"`
|
||||
HeadMsg map[string]any `orm:"head_msg" json:"headMsg"`
|
||||
Form []map[string]any `orm:"form_json" json:"form"`
|
||||
RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"`
|
||||
ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"`
|
||||
ResponseBody string `orm:"response_body" json:"responseBody"`
|
||||
ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"`
|
||||
RequiredFields []string `orm:"required_fields" json:"requiredFields"`
|
||||
IsPrivate *int `orm:"is_private" json:"isPrivate"`
|
||||
IsChatModel *int `orm:"is_chat_model" json:"isChatModel"`
|
||||
CallMode *int `orm:"call_mode" json:"callMode"`
|
||||
ApiKey string `orm:"api_key" json:"apiKey"`
|
||||
Enabled *int `orm:"enabled" json:"enabled"`
|
||||
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
|
||||
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
|
||||
RetryTimes int `orm:"retry_times" json:"retryTimes"`
|
||||
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
|
||||
IsOwner *int `json:"isOwner" orm:"is_owner"`
|
||||
OperatorName string `orm:"operator_name" json:"operatorName"`
|
||||
TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"`
|
||||
ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"`
|
||||
QueryConfig map[string]any `orm:"query_config" json:"queryConfig"`
|
||||
StreamConfig map[string]any `orm:"stream_config" json:"streamConfig"`
|
||||
FirstFrame string `orm:"first_frame" json:"firstFrame"`
|
||||
LastFrame string `orm:"last_frame" json:"lastFrame"`
|
||||
CallbackUrl string `orm:"callback_url" json:"callbackUrl"`
|
||||
}
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
package entity
|
||||
|
||||
import "gitea.com/red-future/common/beans"
|
||||
|
||||
type asynchModelTypeCol struct {
|
||||
beans.SQLBaseCol
|
||||
TypeID string
|
||||
TypeName string
|
||||
Remark string
|
||||
}
|
||||
|
||||
var AsynchModelTypeCol = asynchModelTypeCol{
|
||||
SQLBaseCol: beans.DefSQLBaseCol,
|
||||
TypeID: "type_id",
|
||||
TypeName: "type_name",
|
||||
Remark: "remark",
|
||||
}
|
||||
|
||||
// AsynchModelType 模型类型(图片/音频/视频等)
|
||||
type AsynchModelType struct {
|
||||
beans.SQLBaseDO `orm:",inline"`
|
||||
TypeID int `orm:"type_id" json:"typeId"`
|
||||
TypeName string `orm:"type_name" json:"type"`
|
||||
Remark string `orm:"remark" json:"remark"`
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package entity
|
||||
|
||||
import (
|
||||
"gitea.com/red-future/common/beans"
|
||||
"gitea.redpowerfuture.com/red-future/common/beans"
|
||||
"github.com/gogf/gf/v2/os/gtime"
|
||||
)
|
||||
|
||||
@@ -62,28 +62,28 @@ var AsynchTaskCol = asynchTaskCol{
|
||||
// AsynchTask 异步任务
|
||||
type AsynchTask struct {
|
||||
beans.SQLBaseDO `orm:",inline"`
|
||||
ModelName string `orm:"model_name" json:"modelName"`
|
||||
TaskID string `orm:"task_id" json:"taskId"`
|
||||
BizName string `orm:"biz_name" json:"bizName"`
|
||||
CallbackURL string `orm:"callback_url" json:"callbackUrl"`
|
||||
ModelKey string `orm:"model_key" json:"modelKey"`
|
||||
State int `orm:"state" json:"state"` // 0排队中/1执行中/2成功/3失败/4已下载
|
||||
OssFile string `orm:"oss_file" json:"ossFile"`
|
||||
FileType string `orm:"file_type" json:"fileType"`
|
||||
FileSize int64 `orm:"file_size" json:"fileSize"`
|
||||
ErrorMsg string `orm:"error_msg" json:"errorMsg"`
|
||||
StartedAt *gtime.Time `orm:"started_at" json:"startedAt"`
|
||||
FinishedAt *gtime.Time `orm:"finished_at" json:"finishedAt"`
|
||||
DurationSeconds int64 `orm:"duration_seconds" json:"durationSeconds"`
|
||||
ExpireAt *gtime.Time `orm:"expire_at" json:"expireAt"` // 已下载(state=4)后的过期时间
|
||||
RetryCount int `orm:"retry_count" json:"retryCount"`
|
||||
EnqueueAt *gtime.Time `orm:"enqueue_at" json:"enqueueAt"`
|
||||
Phase int `orm:"phase" json:"phase"` // 0模型阶段/1OSS阶段
|
||||
TmpFile string `orm:"tmp_file" json:"tmpFile"` // 临时结果文件路径
|
||||
InputRef string `orm:"input_ref" json:"inputRef"`
|
||||
RequestPayload any `orm:"request_payload" json:"requestPayload"`
|
||||
TextResult string `orm:"text_result" json:"text"`
|
||||
EpicycleId int64 `orm:"epicycle_id" json:"epicycleId"` // 轮次ID(用于标识同一轮次的任务)
|
||||
ExpendTokens int64 `orm:"expend_tokens" json:"expendTokens"` // 消耗 token 数
|
||||
RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"-"`
|
||||
ModelName string `orm:"model_name" json:"modelName"`
|
||||
TaskID string `orm:"task_id" json:"taskId"`
|
||||
BizName string `orm:"biz_name" json:"bizName"`
|
||||
CallbackURL string `orm:"callback_url" json:"callbackUrl"`
|
||||
ModelKey string `orm:"model_key" json:"modelKey"`
|
||||
State int `orm:"state" json:"state"` // 0排队中/1执行中/2成功/3失败/4已下载
|
||||
OssFile string `orm:"oss_file" json:"ossFile"`
|
||||
FileType string `orm:"file_type" json:"fileType"`
|
||||
FileSize int64 `orm:"file_size" json:"fileSize"`
|
||||
ErrorMsg string `orm:"error_msg" json:"errorMsg"`
|
||||
StartedAt *gtime.Time `orm:"started_at" json:"startedAt"`
|
||||
FinishedAt *gtime.Time `orm:"finished_at" json:"finishedAt"`
|
||||
DurationSeconds int64 `orm:"duration_seconds" json:"durationSeconds"`
|
||||
ExpireAt *gtime.Time `orm:"expire_at" json:"expireAt"` // 已下载(state=4)后的过期时间
|
||||
RetryCount int `orm:"retry_count" json:"retryCount"`
|
||||
EnqueueAt *gtime.Time `orm:"enqueue_at" json:"enqueueAt"`
|
||||
Phase int `orm:"phase" json:"phase"` // 0模型阶段/1OSS阶段
|
||||
TmpFile string `orm:"tmp_file" json:"tmpFile"` // 临时结果文件路径
|
||||
InputRef string `orm:"input_ref" json:"inputRef"`
|
||||
RequestPayload map[string]any `orm:"request_payload" json:"requestPayload"`
|
||||
TextResult map[string]any `orm:"text_result" json:"text"`
|
||||
EpicycleId int64 `orm:"epicycle_id" json:"epicycleId"` // 轮次ID(用于标识同一轮次的任务)
|
||||
ExpendTokens int64 `orm:"expend_tokens" json:"expendTokens"` // 消耗 token 数
|
||||
RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"-"`
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package entity
|
||||
|
||||
import (
|
||||
"gitea.com/red-future/common/beans"
|
||||
"gitea.redpowerfuture.com/red-future/common/beans"
|
||||
)
|
||||
|
||||
type LogsModelPpCol struct {
|
||||
|
||||
@@ -1,67 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"model-gateway/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/http"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
// triggerCallback 任务成功后的回调:
|
||||
// - JSON body 参数:task_id/state/oss_file/file_type/text(可选)
|
||||
func triggerCallback(ctx context.Context, t *entity.AsynchTask) {
|
||||
callbackURL := t.BizName + t.CallbackURL
|
||||
headers := forwardHeaders(ctx)
|
||||
var req struct{}
|
||||
payload := map[string]interface{}{
|
||||
"task_id": t.TaskID,
|
||||
"state": t.State,
|
||||
"oss_file": t.OssFile,
|
||||
"file_type": t.FileType,
|
||||
"text": t.TextResult,
|
||||
"error_msg": t.ErrorMsg,
|
||||
}
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[回调] JSON序列化失败 taskId=%s 错误=%v", t.TaskID, err)
|
||||
return
|
||||
}
|
||||
g.Log().Infof(ctx, "[回调] 开始发送 taskId=%s 回调地址=%s 请求头数量=%d 消息体大小=%d字节",
|
||||
t.TaskID, callbackURL, len(headers), len(jsonData))
|
||||
|
||||
err = http.Post(ctx, callbackURL, headers, &req, jsonData)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[回调] 发送失败 taskId=%s 回调地址=%s 错误=%v", t.TaskID, callbackURL, err)
|
||||
return
|
||||
}
|
||||
g.Log().Infof(ctx, "[回调] 发送成功 taskId=%s 回调地址=%s 消息体大小=%d字节", t.TaskID, callbackURL, len(jsonData))
|
||||
}
|
||||
|
||||
// triggerPromptsCallback 任务成功后的提示词回调
|
||||
// - JSON body 参数:epicycleId(轮次id)/textResult(模型回答消息)
|
||||
func triggerPromptsCallback(ctx context.Context, t *entity.AsynchTask, epicycleId int64) {
|
||||
callbackURL := "prompts-core/session/sessionCallback"
|
||||
headers := forwardHeaders(ctx)
|
||||
var req struct{}
|
||||
payload := map[string]interface{}{
|
||||
"epicycleId": epicycleId,
|
||||
"text": t.TextResult,
|
||||
}
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[提示词回调] JSON序列化失败 epicycleId=%d 错误=%v", epicycleId, err)
|
||||
return
|
||||
}
|
||||
g.Log().Infof(ctx, "[提示词回调] 开始发送 epicycleId=%d 回调地址=%s 请求头数量=%d 消息体大小=%d字节",
|
||||
t.EpicycleId, callbackURL, len(headers), len(jsonData))
|
||||
|
||||
err = http.Post(ctx, callbackURL, headers, &req, jsonData)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[提示词回调] 发送失败 epicycleId=%d 回调地址=%s 错误=%v", t.EpicycleId, callbackURL, err)
|
||||
return
|
||||
}
|
||||
g.Log().Infof(ctx, "[提示词回调] 发送成功 epicycleId=%d 回调地址=%s 消息体大小=%d字节", t.EpicycleId, callbackURL, len(jsonData))
|
||||
}
|
||||
@@ -1,47 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// DetectFileType sniffs data with net/http.DetectContentType and returns the
// bare MIME type (any parameters such as "; charset=utf-8" are removed)
// together with a file extension that includes the leading dot.
//
// Empty input, and content that can only be classified as
// "application/octet-stream", yield an empty extension so that callers can
// apply their own default instead of receiving ".octet-stream". Types without
// a dedicated mapping fall back to the MIME subtype (for example
// "image/gif" -> ".gif").
func DetectFileType(data []byte) (contentType string, ext string) {
	const unknown = "application/octet-stream"
	if len(data) == 0 {
		return unknown, ""
	}

	ct := http.DetectContentType(data)
	// Drop parameters: "text/plain; charset=utf-8" -> "text/plain".
	if base, _, found := strings.Cut(ct, ";"); found && base != "" {
		ct = strings.TrimSpace(base)
	}

	switch ct {
	case unknown:
		return ct, ""
	case "audio/mpeg":
		return ct, ".mp3"
	case "audio/wave", "audio/wav", "audio/x-wav":
		return ct, ".wav"
	case "video/mp4":
		return ct, ".mp4"
	case "image/png":
		return ct, ".png"
	case "image/jpeg":
		return ct, ".jpg"
	case "application/pdf":
		return ct, ".pdf"
	case "text/plain":
		return ct, ".txt"
	case "application/json":
		return ct, ".json"
	}

	// Fallback: use the subtype as the extension, but only for a well-formed
	// "type/subtype" value with a non-empty subtype.
	if _, sub, ok := strings.Cut(ct, "/"); ok && sub != "" && !strings.Contains(sub, "/") {
		return ct, "." + sub
	}
	return ct, ""
}
|
||||
195
service/gateway/gateway_http_service.go
Normal file
195
service/gateway/gateway_http_service.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"mime/multipart"
|
||||
"model-gateway/common/util"
|
||||
"model-gateway/model/entity"
|
||||
"time"
|
||||
|
||||
commonHttp "gitea.redpowerfuture.com/red-future/common/http"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/guid"
|
||||
)
|
||||
|
||||
// UploadFileResponse is the body decoded from the OSS upload endpoint.
type UploadFileResponse struct {
	// FileURL is the URL of the stored file.
	FileURL string `json:"fileURL"`
	// FileSize is the file size in bytes.
	FileSize int `json:"fileSize"`
	// FileName is the name of the stored file.
	FileName string `json:"fileName"`
	// FileFormat is the file format.
	FileFormat string `json:"fileFormat"`
	// FileAddressPrefix is the prefix of the file address.
	FileAddressPrefix string `json:"fileAddressPrefix"`
}
|
||||
|
||||
func UploadByTask(ctx context.Context, data []byte, fileExt string) (oss *UploadFileResponse, err error) {
|
||||
// multipart
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
|
||||
ext := fileExt
|
||||
if ext == "" {
|
||||
ext = ".bin"
|
||||
}
|
||||
if ext[0] != '.' {
|
||||
ext = "." + ext
|
||||
}
|
||||
|
||||
filename := fmt.Sprintf("asynch_%d_%s%s", time.Now().Unix(), guid.S(), ext)
|
||||
part, err := writer.CreateFormFile("file", filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := part.Write(data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
contentType := writer.FormDataContentType()
|
||||
if err = writer.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
headers := util.ForwardHeaders(ctx)
|
||||
headers["Content-Type"] = contentType
|
||||
fullURL := "oss/file/uploadFile"
|
||||
g.Log().Infof(ctx, "[OSS] upload start url=%s filename=%s size=%d", fullURL, filename, len(data))
|
||||
|
||||
var resp UploadFileResponse
|
||||
if err = commonHttp.Post(ctx, fullURL, headers, &resp, body.Bytes()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if &resp == nil {
|
||||
return nil, errors.New("[OSS] 上传文件失败")
|
||||
}
|
||||
g.Log().Infof(ctx, "[OSS] 上传成功 url=%s size=%d format=%s", resp.FileURL, resp.FileSize, resp.FileFormat)
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// CallbackPayload is the JSON request body sent by TriggerCallback.
type CallbackPayload struct {
	// TaskId is serialized as "task_id".
	TaskId string `json:"task_id"`
	// State is serialized as "state".
	State int `json:"state"`
	// OssFile is serialized as "oss_file".
	OssFile string `json:"oss_file"`
	// FileType is serialized as "file_type".
	FileType string `json:"file_type"`
	// Messages is serialized as "messages".
	Messages map[string]any `json:"messages"`
	// ErrorMsg is serialized as "error_msg".
	ErrorMsg string `json:"error_msg"`
}
|
||||
|
||||
// TriggerCallback 任务的回调
|
||||
func TriggerCallback(ctx context.Context, t *entity.AsynchTask) {
|
||||
headers := util.ForwardHeaders(ctx)
|
||||
var resp struct{}
|
||||
payload := CallbackPayload{
|
||||
TaskId: t.TaskID,
|
||||
State: t.State,
|
||||
OssFile: t.OssFile,
|
||||
FileType: t.FileType,
|
||||
Messages: t.TextResult,
|
||||
ErrorMsg: t.ErrorMsg,
|
||||
}
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[回调] JSON序列化失败 taskId=%s 错误=%v", t.TaskID, err)
|
||||
return
|
||||
}
|
||||
g.Log().Infof(ctx, "[回调] 开始发送 taskId=%s 回调地址=%s 请求头数量=%d 消息体大小=%d字节",
|
||||
t.TaskID, t.CallbackURL, len(headers), len(jsonData))
|
||||
|
||||
err = commonHttp.Post(ctx, t.CallbackURL, headers, &resp, jsonData)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[回调] 发送失败 taskId=%s 回调地址=%s 错误=%v", t.TaskID, t.CallbackURL, err)
|
||||
return
|
||||
}
|
||||
g.Log().Infof(ctx, "[回调] 发送成功 taskId=%s 回调地址=%s 消息体大小=%d字节", t.TaskID, t.CallbackURL, len(jsonData))
|
||||
}
|
||||
|
||||
// PromptsCallbackPayload is the JSON request body sent by
// TriggerPromptsCallback.
type PromptsCallbackPayload struct {
	// EpicycleId is serialized as "epicycleId".
	EpicycleId int64 `json:"epicycleId"`
	// Messages is serialized as "messages".
	Messages map[string]any `json:"messages"`
}
|
||||
|
||||
// TriggerPromptsCallback 任务成功后的提示词回调
|
||||
func TriggerPromptsCallback(ctx context.Context, t *entity.AsynchTask, epicycleId int64) {
|
||||
callbackURL := "prompts-core/session/callback"
|
||||
headers := util.ForwardHeaders(ctx)
|
||||
var resp struct{}
|
||||
payload := PromptsCallbackPayload{
|
||||
EpicycleId: epicycleId,
|
||||
Messages: t.TextResult,
|
||||
}
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[提示词回调] JSON序列化失败 epicycleId=%d 错误=%v", epicycleId, err)
|
||||
return
|
||||
}
|
||||
g.Log().Infof(ctx, "[提示词回调] 开始发送 epicycleId=%d 回调地址=%s 请求头数量=%d 消息体大小=%d字节",
|
||||
t.EpicycleId, callbackURL, len(headers), len(jsonData))
|
||||
|
||||
err = commonHttp.Post(ctx, callbackURL, headers, &resp, jsonData)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[提示词回调] 发送失败 epicycleId=%d 回调地址=%s 错误=%v", t.EpicycleId, callbackURL, err)
|
||||
return
|
||||
}
|
||||
g.Log().Infof(ctx, "[提示词回调] 发送成功 epicycleId=%d 回调地址=%s 消息体大小=%d字节", t.EpicycleId, callbackURL, len(jsonData))
|
||||
}
|
||||
|
||||
// IsSuperAdmin 调用admin-go服务检查是否是超级管理员
|
||||
func IsSuperAdmin(ctx context.Context) (res bool, err error) {
|
||||
headers := util.ForwardHeaders(ctx)
|
||||
var r = make(map[string]bool)
|
||||
if err = commonHttp.Get(ctx, "admin-go/api/v1/system/user/checkIsSuperAdmin", headers, &r); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return r["isSuperAdmin"], err
|
||||
}
|
||||
|
||||
//// callback 向回调地址 POST 任务结果(与查询接口 GetTaskRes 出参一致)
|
||||
//func (s *audioTaskService) callback(ctx context.Context, taskID, status, errMsg, callbackURL string) {
|
||||
// if callbackURL == "" {
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// task, _ := dao.TranscribeTask.GetByTaskID(ctx, taskID)
|
||||
// if task == nil {
|
||||
// g.Log().Errorf(ctx, "[回调 %s] 任务不存在", taskID)
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// detailList, _ := dao.TranscribeTaskDetail.ListByTaskID(ctx, taskID)
|
||||
// detailItems := make([]dto.TranscribeTaskDetailItem, 0, len(detailList))
|
||||
// for i := range detailList {
|
||||
// detailItems = append(detailItems, dao.DetailEntityToItem(&detailList[i]))
|
||||
// }
|
||||
//
|
||||
// // 构建与查询接口一致的 taskInfo
|
||||
// taskInfo := dao.EntityToItem(task)
|
||||
//
|
||||
// // 兼容历史数据: 从 result 中补全 scenes 等字段
|
||||
// detailItems = enrichDetailsFromResult(task.Result, detailItems)
|
||||
//
|
||||
// payload := dto.CallbackPayload{
|
||||
// TaskInfo: taskInfo,
|
||||
// DetailList: detailItems,
|
||||
// }
|
||||
//
|
||||
// body, _ := json.Marshal(payload)
|
||||
//
|
||||
// // 透传调用方的用户信息
|
||||
// userJSON, _ := json.Marshal(beans.User{UserName: "admin", TenantId: 1})
|
||||
//
|
||||
// req, _ := http.NewRequest("POST", callbackURL, bytes.NewReader(body))
|
||||
// req.Header.Set("Content-Type", "application/json")
|
||||
// req.Header.Set("X-User-Info", string(userJSON))
|
||||
//
|
||||
// resp, reqErr := http.DefaultClient.Do(req)
|
||||
// if reqErr != nil {
|
||||
// g.Log().Errorf(ctx, "[回调 %s] 请求失败: %v", taskID, reqErr)
|
||||
// return
|
||||
// }
|
||||
// defer resp.Body.Close()
|
||||
//
|
||||
// respBody, _ := io.ReadAll(resp.Body)
|
||||
// g.Log().Infof(ctx, "[回调 %s] 响应 status=%d, body=%s", taskID, resp.StatusCode, string(respBody))
|
||||
//}
|
||||
@@ -1,53 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
// asyncCtx 固化异步执行所需的 token/user,避免请求结束后丢失(仅在“同请求内起 goroutine”有用)。
|
||||
// 本项目当前是“落库 + 后台 worker”模式,因此还会把必要信息持久化到任务表的 request_payload 中。
|
||||
func asyncCtx(ctx context.Context) context.Context {
|
||||
asyncCtx := context.WithoutCancel(ctx)
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
if token := r.Header.Get("Authorization"); token != "" {
|
||||
asyncCtx = context.WithValue(asyncCtx, "token", token)
|
||||
}
|
||||
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
|
||||
asyncCtx = context.WithValue(asyncCtx, "xUserInfo", userInfo)
|
||||
}
|
||||
}
|
||||
if user, err := utils.GetUserInfo(ctx); err == nil && user != nil {
|
||||
asyncCtx = context.WithValue(asyncCtx, "user", user)
|
||||
}
|
||||
return asyncCtx
|
||||
}
|
||||
|
||||
// forwardHeaders 透传调用链路中必须的头信息(优先使用 ctx 里固化的 token / xUserInfo)。
|
||||
func forwardHeaders(ctx context.Context) map[string]string {
|
||||
headers := make(map[string]string)
|
||||
|
||||
if token, ok := ctx.Value("token").(string); ok && token != "" {
|
||||
headers["Authorization"] = token
|
||||
}
|
||||
if x, ok := ctx.Value("xUserInfo").(string); ok && x != "" {
|
||||
headers["X-User-Info"] = x
|
||||
}
|
||||
|
||||
// 兜底:从请求头拿
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
if headers["Authorization"] == "" {
|
||||
if token := r.Header.Get("Authorization"); token != "" {
|
||||
headers["Authorization"] = token
|
||||
}
|
||||
}
|
||||
if headers["X-User-Info"] == "" {
|
||||
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
|
||||
headers["X-User-Info"] = userInfo
|
||||
}
|
||||
}
|
||||
}
|
||||
return headers
|
||||
}
|
||||
@@ -1,7 +1,10 @@
|
||||
package service
|
||||
package job
|
||||
|
||||
import (
|
||||
"context"
|
||||
"model-gateway/model/dto"
|
||||
"model-gateway/service/queue"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"model-gateway/dao"
|
||||
@@ -14,35 +17,36 @@ var Cleaner = &cleaner{}
|
||||
type cleaner struct{}
|
||||
|
||||
// RunOnce 由上层定时任务触发:执行一次清理/重试
|
||||
func (c *cleaner) RunOnce(ctx context.Context) {
|
||||
func (c *cleaner) RunOnce(ctx context.Context) (res *dto.CleanWorkRes, err error) {
|
||||
// 1) 清理已下载(state=4)且过期的任务(硬删除 + OSS)
|
||||
expired, err := dao.Task.ListExpiredDownloadedGlobal(ctx, 200)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[cleaner] list expired(downloaded) error: %v", err)
|
||||
g.Log().Errorf(ctx, "[清理] 查询已下载过期任务失败: %v", err)
|
||||
} else {
|
||||
for _, t := range expired {
|
||||
deleteTmpResult(t.TmpFile)
|
||||
_ = os.Remove(t.TmpFile)
|
||||
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
|
||||
}
|
||||
g.Log().Infof(ctx, "[cleaner] expired(downloaded) cleaned, count=%d", len(expired))
|
||||
g.Log().Infof(ctx, "[清理] 已下载过期任务清理完成, count=%d", len(expired))
|
||||
}
|
||||
|
||||
// 2) 超时任务标失败
|
||||
list, err := dao.Task.ListTimeoutTasksGlobal(ctx, 200)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[cleaner] list timeout error: %v", err)
|
||||
g.Log().Errorf(ctx, "[清理] 查询超时任务失败: %v", err)
|
||||
} else {
|
||||
for _, t := range list {
|
||||
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, "任务超时自动失败")
|
||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
t.ErrorMsg = "任务超时自动失败"
|
||||
_ = dao.Task.UpdateFailedGlobal(ctx, t)
|
||||
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
}
|
||||
g.Log().Infof(ctx, "[cleaner] timeout cleaned, count=%d", len(list))
|
||||
g.Log().Infof(ctx, "[清理] 超时任务处理完成, count=%d", len(list))
|
||||
}
|
||||
|
||||
// 3) 失败(state=3)的任务按模型配置 retry_times 重新入队(放到队尾)
|
||||
retryable, err := dao.Task.ListFailedRetryableGlobal(ctx, 200)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[cleaner] list failed retryable error: %v", err)
|
||||
g.Log().Errorf(ctx, "[清理] 查询可重试任务失败: %v", err)
|
||||
} else {
|
||||
for _, t := range retryable {
|
||||
// 失败任务重新入队(state=3 -> 0)前,先严格占用 queue_limit slot;占用失败则留在失败态,下一轮再尝试
|
||||
@@ -51,9 +55,9 @@ func (c *cleaner) RunOnce(ctx context.Context) {
|
||||
if err != nil || m == nil {
|
||||
continue
|
||||
}
|
||||
limit := GetRuntimeQueueLimit(ctx, t.ModelName, m.QueueLimit)
|
||||
limit := queue.GetRuntimeQueueLimit(ctx, t.ModelName, m.MaxConcurrency*2)
|
||||
if limit > 0 {
|
||||
ok, _ := AcquireQueueSlot(ctx, t.ModelName, t.TaskID, limit, m.ExpectedSeconds)
|
||||
ok, _ := queue.AcquireQueueSlot(ctx, t.ModelName, t.TaskID, limit, m.TimeoutSeconds)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
@@ -73,20 +77,23 @@ func (c *cleaner) RunOnce(ctx context.Context) {
|
||||
}
|
||||
_ = dao.Task.RequeueForRetryGlobal(ctx, t.Id, enqueueAt)
|
||||
}
|
||||
g.Log().Infof(ctx, "[cleaner] failed retryable cleaned, count=%d", len(retryable))
|
||||
g.Log().Infof(ctx, "[清理] 可重试任务重新入队完成, count=%d", len(retryable))
|
||||
}
|
||||
|
||||
// 4) 超过重试次数仍失败(state=3)的任务:硬删除
|
||||
exhausted, err := dao.Task.ListFailedExhaustedGlobal(ctx, 200)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[cleaner] list failed exhausted error: %v", err)
|
||||
g.Log().Errorf(ctx, "[清理] 查询重试耗尽任务失败: %v", err)
|
||||
} else {
|
||||
for _, t := range exhausted {
|
||||
deleteTmpResult(t.TmpFile)
|
||||
_ = os.Remove(t.TmpFile)
|
||||
// 重试耗尽硬删除:释放闸门占位(兜底,若此前已释放则幂等)
|
||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
|
||||
}
|
||||
g.Log().Infof(ctx, "[cleaner] failed exhausted cleaned, count=%d", len(exhausted))
|
||||
g.Log().Infof(ctx, "[清理] 重试耗尽任务清理完成, count=%d", len(exhausted))
|
||||
}
|
||||
return &dto.CleanWorkRes{
|
||||
Ok: true,
|
||||
}, nil
|
||||
}
|
||||
256
service/model/model_service.go
Normal file
256
service/model/model_service.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"model-gateway/common/util"
|
||||
"model-gateway/consts/public"
|
||||
"model-gateway/dao"
|
||||
"model-gateway/model/dto"
|
||||
"model-gateway/model/entity"
|
||||
"model-gateway/service/gateway"
|
||||
|
||||
"gitea.redpowerfuture.com/red-future/common/beans"
|
||||
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
|
||||
"gitea.redpowerfuture.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// modelService groups the model-management operations; it has no state.
type modelService struct{}

// Model is the shared modelService instance used by callers.
var Model = new(modelService)
|
||||
|
||||
// Create 创建模型
|
||||
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (*dto.CreateModelRes, error) {
|
||||
// 1)如果设为会话模型,先把该用户旧会话模型取消
|
||||
if !g.IsEmpty(req.IsChatModel) && *req.IsChatModel == 1 {
|
||||
if err := s.clearUserChatModel(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
// 2)判断是否超管,决定 isOwner
|
||||
req.IsOwner = gconv.PtrInt(1)
|
||||
if isAdmin, _ := gateway.IsSuperAdmin(ctx); isAdmin {
|
||||
req.IsOwner = gconv.PtrInt(0)
|
||||
}
|
||||
|
||||
// 3)入库
|
||||
id, err := dao.Model.Insert(ctx, util.ConvertTo[entity.AsynchModel](req))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.CreateModelRes{ID: id}, nil
|
||||
}
|
||||
|
||||
// Update 更新模型配置
|
||||
func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) error {
|
||||
// 1)会话模型唯一性校验
|
||||
if req.IsChatModel != nil && *req.IsChatModel == 1 {
|
||||
if err := s.checkChatModelUnique(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// 2)超管创建/普通用户更新
|
||||
req.IsOwner = gconv.PtrInt(1)
|
||||
if isAdmin, _ := gateway.IsSuperAdmin(ctx); isAdmin {
|
||||
req.IsOwner = gconv.PtrInt(0)
|
||||
_, err := dao.Model.Update(ctx, util.ConvertTo[entity.AsynchModel](req))
|
||||
return err
|
||||
}
|
||||
// 3)跨租户判断:超管的模型不允许直接修改,走插入新记录
|
||||
model, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if model.TenantId == 1 {
|
||||
_, err = dao.Model.Insert(ctx, util.ConvertTo[entity.AsynchModel](req))
|
||||
return err
|
||||
}
|
||||
_, err = dao.Model.Update(ctx, util.ConvertTo[entity.AsynchModel](req))
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete 删除模型
|
||||
func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) error {
|
||||
_, err := dao.Model.Delete(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// Get 获取模型详情
|
||||
func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetModelRes, error) {
|
||||
user, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if g.IsEmpty(req.ID) {
|
||||
req.Creator = user.UserName
|
||||
}
|
||||
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{
|
||||
Id: req.ID,
|
||||
Creator: user.UserName,
|
||||
},
|
||||
ModelName: req.ModelName,
|
||||
IsChatModel: req.IsChatModel,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.GetModelRes{
|
||||
Model: model,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// List 获取模型列表
|
||||
func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (*dto.ListModelRes, error) {
|
||||
// 1)判断超管
|
||||
req.IsOwner = gconv.PtrInt(1)
|
||||
if isAdmin, _ := gateway.IsSuperAdmin(ctx); isAdmin {
|
||||
req.IsOwner = gconv.PtrInt(0)
|
||||
}
|
||||
|
||||
// 2)获取当前用户
|
||||
user, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Creator = user.UserName
|
||||
|
||||
// 3)查询
|
||||
models, total, err := dao.Model.GetByCreatorAndPlatform(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &dto.ListModelRes{List: models, Total: total}, nil
|
||||
}
|
||||
|
||||
// UpdateChatModel 设置会话模型
|
||||
func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error {
|
||||
// 1)校验新模型存在
|
||||
newModel, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
|
||||
})
|
||||
if err != nil || newModel == nil {
|
||||
return errors.New("新会话模型不存在")
|
||||
}
|
||||
|
||||
// 2)获取当前用户的会话模型
|
||||
user, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
currentModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
|
||||
IsChatModel: gconv.PtrInt(1),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 3)事务:取消旧的 + 设置新的
|
||||
return gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
||||
if !g.IsEmpty(currentModel) {
|
||||
if currentModel.ModelType != public.ModelTypeInference {
|
||||
return errors.New("当前模型为非推理模型,不能设置为会话模型")
|
||||
}
|
||||
if currentModel.Id != req.Id {
|
||||
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id},
|
||||
IsChatModel: gconv.PtrInt(0),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
|
||||
IsChatModel: gconv.PtrInt(1),
|
||||
})
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
// GetIsChatModel 获取当前用户会话模型
|
||||
func (s *modelService) GetIsChatModel(ctx context.Context) (*dto.GetIsChatModelRes, error) {
|
||||
user, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
|
||||
IsChatModel: gconv.PtrInt(1),
|
||||
})
|
||||
if err != nil || model == nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.GetIsChatModelRes{Model: model}, nil
|
||||
}
|
||||
|
||||
// ==================== 辅助方法 ====================
|
||||
|
||||
// clearUserChatModel 清除当前用户旧会话模型
|
||||
func (s *modelService) clearUserChatModel(ctx context.Context) error {
|
||||
user, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
|
||||
IsChatModel: gconv.PtrInt(1),
|
||||
})
|
||||
if err != nil || model == nil {
|
||||
return nil
|
||||
}
|
||||
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: model.Id},
|
||||
IsChatModel: gconv.PtrInt(0),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// checkChatModelUnique 校验用户是否已有会话模型
|
||||
func (s *modelService) checkChatModelUnique(ctx context.Context) error {
|
||||
user, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
|
||||
IsChatModel: gconv.PtrInt(1),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if model != nil {
|
||||
return errors.New("用户已存在会话模型")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetModelTypesFromConfig 从配置文件读取模型类型
|
||||
func GetModelTypesFromConfig() (res *dto.TypeItem, err error) {
|
||||
// 返回副本,避免外部修改
|
||||
types := make(map[int]string, len(public.ModelTypeName))
|
||||
for k, v := range public.ModelTypeName {
|
||||
types[k] = v
|
||||
}
|
||||
return &dto.TypeItem{
|
||||
Type: types,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetOperatorList 获取运营商列表
|
||||
func GetOperatorList() (res *dto.ListOperatorRes, err error) {
|
||||
return &dto.ListOperatorRes{
|
||||
List: public.OperatorList,
|
||||
}, nil
|
||||
}
|
||||
@@ -1,417 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"model-gateway/model/entity"
|
||||
|
||||
"github.com/gogf/gf/v2/container/gvar"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// parseHeadMsgHeaders parses a comma-separated list of header bindings.
//
// Examples:
//   - X-API-Key:qwen3-tts-key,operation:true,count:123
//   - X-API-Key:"qwen3-tts-key",operation:"true"
//
// Each entry is "Name:Value" (preferred) or "Name=Value" (compatibility). An
// entry that contains ':' is always split at its first ':', even when '='
// appears earlier. Name and value are trimmed, and double quotes around the
// value are stripped. Entries without a separator, or with an empty name or
// value, are ignored.
//
// HTTP header values are strings in the end; quoting only helps to express
// the value unambiguously in configuration. The function returns nil when no
// valid entry is found.
func parseHeadMsgHeaders(headMsg string) map[string]string {
	trimmed := strings.TrimSpace(headMsg)
	if trimmed == "" {
		return nil
	}

	headers := map[string]string{}
	for _, entry := range strings.Split(trimmed, ",") {
		entry = strings.TrimSpace(entry)

		// ':' takes priority over '='; entries with neither are skipped
		// (this also covers empty entries).
		var sep string
		switch {
		case strings.Contains(entry, ":"):
			sep = ":"
		case strings.Contains(entry, "="):
			sep = "="
		default:
			continue
		}

		at := strings.Index(entry, sep)
		name := strings.TrimSpace(entry[:at])
		value := strings.Trim(strings.TrimSpace(entry[at+1:]), "\"")
		if name == "" || value == "" {
			continue
		}
		headers[name] = value
	}

	if len(headers) == 0 {
		return nil
	}
	return headers
}
|
||||
|
||||
// payloadToQuery converts payload into URL query parameters.
//
// payload is round-tripped through JSON and must encode to a JSON object (or
// null); each top-level key becomes one query parameter:
//   - null values are skipped;
//   - strings are used as they are;
//   - numbers keep their JSON literal, so 1000000 stays "1000000" instead of
//     turning into "1e+06", and large integers keep full precision;
//   - booleans become "true" or "false";
//   - arrays and objects are encoded as JSON text.
//
// A nil payload yields an empty, non-nil url.Values. An error is returned
// when payload cannot be marshalled or does not encode to a JSON object.
func payloadToQuery(payload any) (url.Values, error) {
	if payload == nil {
		return url.Values{}, nil
	}

	raw, err := json.Marshal(payload)
	if err != nil {
		return nil, err
	}

	// UseNumber keeps numeric literals intact instead of converting them to
	// float64, whose default formatting switches to exponent notation.
	dec := json.NewDecoder(bytes.NewReader(raw))
	dec.UseNumber()
	fields := map[string]any{}
	if err := dec.Decode(&fields); err != nil {
		return nil, err
	}

	q := url.Values{}
	for k, v := range fields {
		switch vv := v.(type) {
		case nil:
			continue
		case string:
			q.Set(k, vv)
		case json.Number:
			q.Set(k, vv.String())
		case bool:
			q.Set(k, fmt.Sprintf("%v", vv))
		default:
			// Arrays and objects: pass them on as JSON text.
			bs, err := json.Marshal(vv)
			if err != nil {
				return nil, err
			}
			q.Set(k, string(bs))
		}
	}
	return q, nil
}
|
||||
|
||||
// InvokeModel 调用模型服务,返回二进制结果
|
||||
// modelKey 用于覆盖/补充模型配置 head_msg(例如每次请求携带不同的 X-API-Key)。
|
||||
func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelKey string) ([]byte, error) {
|
||||
if m == nil || m.BaseURL == "" {
|
||||
return nil, fmt.Errorf("模型配置不完整")
|
||||
}
|
||||
|
||||
// ============ 新增:请求参数映射 ============
|
||||
mappedPayload, err := mapRequestPayload(m.RequestMapping, payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求参数映射失败: %w", err)
|
||||
}
|
||||
|
||||
url := strings.TrimRight(m.BaseURL, "/")
|
||||
timeout := time.Duration(m.TimeoutSeconds) * time.Second
|
||||
if timeout <= 0 {
|
||||
timeout = 60 * time.Second
|
||||
}
|
||||
client := &http.Client{Timeout: timeout}
|
||||
|
||||
method := strings.ToUpper(strings.TrimSpace(m.HttpMethod))
|
||||
if method == "" {
|
||||
method = http.MethodPost
|
||||
}
|
||||
|
||||
var (
|
||||
req *http.Request
|
||||
)
|
||||
switch method {
|
||||
case http.MethodGet:
|
||||
q, err := payloadToQuery(mappedPayload) // 使用映射后的payload
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(q) > 0 {
|
||||
if strings.Contains(url, "?") {
|
||||
url = url + "&" + q.Encode()
|
||||
} else {
|
||||
url = url + "?" + q.Encode()
|
||||
}
|
||||
}
|
||||
req, err = http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
default:
|
||||
bodyBytes, err := json.Marshal(mappedPayload) // 使用映射后的payload
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err = http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyBytes))
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 先注入模型配置 head_msg(静态头部,适合公共模型固定 API Key)
|
||||
for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) {
|
||||
req.Header.Set(hk, hv)
|
||||
}
|
||||
|
||||
// 最后注入动态 modelKey(允许覆盖/补充静态 head_msg),适合按请求动态传密钥。
|
||||
for hk, hv := range parseHeadMsgHeaders(modelKey) {
|
||||
req.Header.Set(hk, hv)
|
||||
}
|
||||
|
||||
if method != http.MethodGet {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
msg := string(b)
|
||||
if len(msg) > 2000 {
|
||||
msg = msg[:2000]
|
||||
}
|
||||
return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
|
||||
}
|
||||
|
||||
// ============ 新增:响应参数映射 ============
|
||||
mappedResponse, err := mapResponsePayload(m.ResponseMapping, b)
|
||||
if err != nil {
|
||||
// 响应映射失败不阻塞,返回原始数据
|
||||
g.Log().Warningf(ctx, "响应参数映射失败: %v,返回原始数据", err)
|
||||
return b, nil
|
||||
}
|
||||
// =========================================
|
||||
|
||||
return mappedResponse, nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 映射相关函数
|
||||
// ============================================
|
||||
|
||||
// mapRequestPayload 将标准请求映射为模型特定格式
|
||||
func mapRequestPayload(mappingAny any, payload any) (any, error) {
|
||||
// 1. 解析请求映射配置(值是any类型,支持bool、number等)
|
||||
mapping, err := parseRequestMapping(mappingAny)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 如果没有映射配置,直接返回原始payload
|
||||
if len(mapping) == 0 {
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
// 2. 将payload转为map
|
||||
var payloadMap map[string]any
|
||||
switch v := payload.(type) {
|
||||
case map[string]any:
|
||||
payloadMap = v
|
||||
case []map[string]any:
|
||||
// 如果传进来的是纯messages数组,包装成标准格式
|
||||
payloadMap = map[string]any{
|
||||
"messages": v,
|
||||
}
|
||||
default:
|
||||
// 通过JSON转换
|
||||
jsonBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化payload失败: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(jsonBytes, &payloadMap); err != nil {
|
||||
return nil, fmt.Errorf("反序列化payload失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 用数据库固定参数覆盖/补充
|
||||
for key, value := range mapping {
|
||||
if existingValue, exists := payloadMap[key]; !exists || isEmptyValue(existingValue) {
|
||||
payloadMap[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
return payloadMap, nil
|
||||
}
|
||||
|
||||
// mapResponsePayload 将模型响应映射为标准格式
|
||||
func mapResponsePayload(mappingAny any, responseBytes []byte) ([]byte, error) {
|
||||
mapping, err := parseResponseMapping(mappingAny)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(mapping) == 0 {
|
||||
return responseBytes, nil
|
||||
}
|
||||
|
||||
responseStr := string(responseBytes)
|
||||
resultStr := `{}`
|
||||
|
||||
for standardField, modelPath := range mapping {
|
||||
value := gjson.Get(responseStr, modelPath)
|
||||
if !value.Exists() {
|
||||
continue
|
||||
}
|
||||
|
||||
resultStr, err = sjson.SetRaw(resultStr, standardField, value.Raw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("提取字段 %s <- %s 失败: %w", standardField, modelPath, err)
|
||||
}
|
||||
}
|
||||
|
||||
return []byte(resultStr), nil
|
||||
}
|
||||
|
||||
func parseRequestMapping(mappingAny any) (map[string]any, error) {
|
||||
if mappingAny == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
result := make(map[string]any)
|
||||
|
||||
switch v := mappingAny.(type) {
|
||||
case *gvar.Var:
|
||||
if v == nil || v.IsNil() || v.IsEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
// 尝试转成 map
|
||||
if m := v.Map(); m != nil {
|
||||
for k, val := range m {
|
||||
result[k] = val
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
// 尝试转成 string
|
||||
if s := v.String(); s != "" && s != "{}" && s != "null" {
|
||||
if err := json.Unmarshal([]byte(s), &result); err != nil {
|
||||
return nil, fmt.Errorf("解析请求映射字符串失败: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
return nil, nil
|
||||
// =======================================================
|
||||
|
||||
case map[string]interface{}:
|
||||
result = v
|
||||
|
||||
case string:
|
||||
if v == "" || v == "{}" || v == "null" {
|
||||
return nil, nil
|
||||
}
|
||||
if err := json.Unmarshal([]byte(v), &result); err != nil {
|
||||
return nil, fmt.Errorf("解析请求映射字符串失败: %w", err)
|
||||
}
|
||||
|
||||
case []byte:
|
||||
if len(v) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if err := json.Unmarshal(v, &result); err != nil {
|
||||
return nil, fmt.Errorf("解析请求映射字节失败: %w", err)
|
||||
}
|
||||
|
||||
default:
|
||||
jsonBytes, err := json.Marshal(mappingAny)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化映射配置失败: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(jsonBytes, &result); err != nil {
|
||||
return nil, fmt.Errorf("解析映射配置失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// parseResponseMapping 解析响应映射配置
|
||||
// 返回值类型为 map[string]string,值都是JSON路径字符串
|
||||
func parseResponseMapping(mappingAny any) (map[string]string, error) {
|
||||
if mappingAny == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
mapping := make(map[string]string)
|
||||
|
||||
switch v := mappingAny.(type) {
|
||||
case *gvar.Var:
|
||||
if v == nil || v.IsNil() || v.IsEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
if m := v.Map(); m != nil {
|
||||
for k, val := range m {
|
||||
if strVal, ok := val.(string); ok {
|
||||
mapping[k] = strVal
|
||||
}
|
||||
}
|
||||
return mapping, nil
|
||||
}
|
||||
if s := v.String(); s != "" && s != "{}" && s != "null" {
|
||||
if err := json.Unmarshal([]byte(s), &mapping); err != nil {
|
||||
return nil, fmt.Errorf("解析响应映射字符串失败: %w", err)
|
||||
}
|
||||
return mapping, nil
|
||||
}
|
||||
return nil, nil
|
||||
case string:
|
||||
if v == "" || v == "{}" || v == "null" {
|
||||
return nil, nil
|
||||
}
|
||||
if err := json.Unmarshal([]byte(v), &mapping); err != nil {
|
||||
return nil, fmt.Errorf("解析响应映射字符串失败: %w", err)
|
||||
}
|
||||
|
||||
case map[string]interface{}:
|
||||
// 数据库JSONB直接返回的map
|
||||
for k, val := range v {
|
||||
if strVal, ok := val.(string); ok {
|
||||
mapping[k] = strVal
|
||||
}
|
||||
}
|
||||
|
||||
case []byte:
|
||||
if len(v) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if err := json.Unmarshal(v, &mapping); err != nil {
|
||||
return nil, fmt.Errorf("解析响应映射字节失败: %w", err)
|
||||
}
|
||||
|
||||
default:
|
||||
jsonBytes, err := json.Marshal(mappingAny)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化响应映射配置失败: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(jsonBytes, &mapping); err != nil {
|
||||
return nil, fmt.Errorf("解析响应映射配置失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return mapping, nil
|
||||
}
|
||||
|
||||
// isEmptyValue reports whether v counts as "no value supplied".
//
// Empty values are nil, the empty string, and zero-length []any,
// []map[string]any and map[string]any. Everything else, including 0 and
// false, is treated as a real value.
func isEmptyValue(v any) bool {
	switch val := v.(type) {
	case nil:
		return true
	case string:
		return val == ""
	case []any:
		return len(val) == 0
	case []map[string]any:
		// mapRequestPayload stores a bare messages array under this type, so
		// an empty one must be recognized as empty as well.
		return len(val) == 0
	case map[string]any:
		return len(val) == 0
	default:
		return false
	}
}
|
||||
@@ -1,255 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"model-gateway/dao"
|
||||
"model-gateway/model/dto"
|
||||
"model-gateway/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"gitea.com/red-future/common/http"
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// modelService groups the model-management operations; it carries no state.
type modelService struct{}

// Model is the package-level modelService instance.
var Model = new(modelService)
|
||||
|
||||
// IsSuperAdmin 调用admin-go服务检查是否是超级管理员
|
||||
func (s *modelService) IsSuperAdmin(ctx context.Context) (res bool, err error) {
|
||||
headers := forwardHeaders(ctx)
|
||||
var r = make(map[string]bool)
|
||||
if err = http.Get(ctx, "admin-go/api/v1/system/user/checkIsSuperAdmin", headers, &r); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return r["isSuperAdmin"], err
|
||||
}
|
||||
|
||||
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
|
||||
// 获取当前会话模型
|
||||
if !g.IsEmpty(req.IsChatModel) && *req.IsChatModel == 1 {
|
||||
var model *entity.AsynchModel
|
||||
model, err = dao.Model.GetByIsChatModel(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 如果有会话模型,那就改变为 0
|
||||
if model != nil {
|
||||
_, err = dao.Model.Update(ctx, &dto.UpdateModelReq{
|
||||
ID: model.Id,
|
||||
IsChatModel: gconv.PtrInt(0),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
req.IsOwner = gconv.PtrInt(1)
|
||||
admin, err := s.IsSuperAdmin(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if admin {
|
||||
req.IsOwner = gconv.PtrInt(0)
|
||||
}
|
||||
id, err := dao.Model.Insert(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.CreateModelRes{ID: id}, nil
|
||||
}
|
||||
|
||||
func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) error {
|
||||
//根据当前 isChatModel 来判断是否更新模型
|
||||
if req.IsChatModel == gconv.PtrInt(1) {
|
||||
//判断当前用户是否有会话模型
|
||||
model, err := dao.Model.GetByIsChatModel(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if model != nil {
|
||||
return errors.New("用户已存在会话模型,不能创建")
|
||||
}
|
||||
}
|
||||
|
||||
req.IsOwner = gconv.PtrInt(1)
|
||||
admin, err := s.IsSuperAdmin(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if admin {
|
||||
req.IsOwner = gconv.PtrInt(0)
|
||||
_, err = dao.Model.Update(ctx, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var user *beans.User
|
||||
user, err = utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 判断当前传过来的模型id的模型是否是超级管理员的。如果是超管的进行创建,否则更新
|
||||
var count int
|
||||
count, err = dao.Model.Count(ctx, &dto.GetModelReq{
|
||||
ID: req.ID,
|
||||
Creator: user.UserName,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count == 0 {
|
||||
insertDto := new(dto.CreateModelReq)
|
||||
err = gconv.Struct(req, insertDto)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = dao.Model.Insert(ctx, insertDto)
|
||||
return err
|
||||
}
|
||||
_, err = dao.Model.Update(ctx, req)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *modelService) Delete(ctx context.Context, id string) error {
|
||||
_, err := dao.Model.DeleteByID(ctx, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *modelService) Get(ctx context.Context, id int64) (*entity.AsynchModel, error) {
|
||||
model, err := dao.Model.Get(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
model.Form = ParseJSONField(model.Form)
|
||||
model.RequestMapping = ParseJSONField(model.RequestMapping)
|
||||
model.ResponseMapping = ParseJSONField(model.ResponseMapping)
|
||||
model.ResponseBody = ParseJSONField(model.ResponseBody)
|
||||
return model, nil
|
||||
}
|
||||
|
||||
func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (list []*entity.AsynchModel, total int, err error) {
|
||||
var models []*entity.AsynchModel
|
||||
|
||||
req.IsOwner = gconv.PtrInt(1)
|
||||
admin, err := s.IsSuperAdmin(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if admin {
|
||||
req.IsOwner = gconv.PtrInt(0)
|
||||
}
|
||||
|
||||
var user *beans.User
|
||||
user, err = utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
req.Creator = user.UserName
|
||||
|
||||
models, total, err = dao.Model.GetByCreatorAndPlatform(ctx, req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 处理列表中每条记录的 JSONB 字段
|
||||
for _, m := range models {
|
||||
m.Form = ParseJSONField(m.Form)
|
||||
m.RequestMapping = ParseJSONField(m.RequestMapping)
|
||||
m.ResponseMapping = ParseJSONField(m.ResponseMapping)
|
||||
m.ResponseBody = ParseJSONField(m.ResponseBody)
|
||||
}
|
||||
return models, total, nil
|
||||
}
|
||||
|
||||
// GetModelTypesFromConfig 从配置文件读取模型类型
|
||||
func GetModelTypesFromConfig(ctx context.Context) map[int]string {
|
||||
typeMap := make(map[int]string)
|
||||
|
||||
// 读取配置
|
||||
configMap := g.Cfg().MustGet(ctx, "modelType.types").Map()
|
||||
for k, v := range configMap {
|
||||
typeID := gconv.Int(k)
|
||||
typeName := gconv.String(v)
|
||||
if typeID > 0 && typeName != "" {
|
||||
typeMap[typeID] = typeName
|
||||
}
|
||||
}
|
||||
// 如果配置为空,使用默认值
|
||||
if len(typeMap) == 0 {
|
||||
typeMap = map[int]string{
|
||||
1: "推理模型",
|
||||
2: "图片模型",
|
||||
3: "音频模型",
|
||||
4: "向量化模型",
|
||||
5: "全模态模型",
|
||||
}
|
||||
}
|
||||
return typeMap
|
||||
}
|
||||
|
||||
func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error {
|
||||
// 校验新会话模型是否存在
|
||||
newModel, err := dao.Model.Get(ctx, req.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if newModel == nil {
|
||||
return errors.New("新会话模型不存在")
|
||||
}
|
||||
|
||||
// 获取当前用户会话模型
|
||||
currentModel, err := dao.Model.GetByIsChatModel(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
||||
if !g.IsEmpty(currentModel) {
|
||||
if currentModel.ModelType != 1 {
|
||||
return errors.New("当前模型为非推理模型,不能设置为会话模型")
|
||||
}
|
||||
|
||||
// 如果点击的就是当前会话模型(已经是1),取消它(设为0)
|
||||
if currentModel.Id != req.Id {
|
||||
_, err = dao.Model.Update(ctx, &dto.UpdateModelReq{
|
||||
ID: currentModel.Id,
|
||||
IsChatModel: gconv.PtrInt(0),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 设置当前为会话模型(设为1)
|
||||
_, err = dao.Model.Update(ctx, &dto.UpdateModelReq{
|
||||
ID: req.Id,
|
||||
IsChatModel: gconv.PtrInt(1),
|
||||
})
|
||||
return err
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *modelService) GetIsChatModel(ctx context.Context) (*entity.AsynchModel, error) {
|
||||
model, err := dao.Model.GetByIsChatModel(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if model == nil {
|
||||
return nil, nil
|
||||
}
|
||||
model.Form = ParseJSONField(model.Form)
|
||||
model.RequestMapping = ParseJSONField(model.RequestMapping)
|
||||
model.ResponseMapping = ParseJSONField(model.ResponseMapping)
|
||||
model.ResponseBody = ParseJSONField(model.ResponseBody)
|
||||
return model, nil
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
package service
|
||||
|
||||
import "github.com/gogf/gf/v2/util/gconv"
|
||||
|
||||
// parseStoredPayload 解析入库的 request_payload,拆出模型调用 payload 与透传 headers
|
||||
// 入库格式:{"payload": <any>, "headers": {"Authorization": "...", "X-User-Info":"..."}}
|
||||
func parseStoredPayload(v any) (payload any, headers map[string]string) {
|
||||
if v == nil {
|
||||
return nil, nil
|
||||
}
|
||||
m := gconv.Map(v)
|
||||
if len(m) == 0 {
|
||||
return v, nil
|
||||
}
|
||||
if h, ok := m["headers"]; ok {
|
||||
headers = gconv.MapStrStr(h)
|
||||
}
|
||||
if p, ok := m["payload"]; ok {
|
||||
payload = p
|
||||
} else {
|
||||
payload = v
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
package service
|
||||
package queue
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"model-gateway/model/dto"
|
||||
|
||||
"model-gateway/consts/public"
|
||||
"model-gateway/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
@@ -26,7 +28,6 @@ type AutoTuneResult struct {
|
||||
OldQueueLimit int `json:"oldQueueLimit"` // 调参前运行时值(Redis),若无则等于 cap
|
||||
NewQueueLimit int `json:"newQueueLimit"` // 本次计算出的运行时值(将写入 Redis),受 ±50% 约束且不超过 cap
|
||||
|
||||
ExpectedSeconds int `json:"expectedSeconds"` // 模型预计执行时间(秒):asynch_models.expected_seconds(用于 queue_limit 计算绑定)
|
||||
}
|
||||
|
||||
// AutoTune 由上层定时任务通过接口触发:
|
||||
@@ -34,9 +35,12 @@ type AutoTuneResult struct {
|
||||
// - 基于吞吐与 P90 执行耗时估算 max_concurrency 的运行时值(不超过 cap)
|
||||
// - queue_limit 与 expected_seconds 绑定(允许排队时间 = expected_seconds * 2),生成运行时值(不超过 cap)
|
||||
// - 单次调整幅度限制 ±50%,写入 Redis(带 TTL)
|
||||
func AutoTune(ctx context.Context, windowSeconds int) ([]AutoTuneResult, error) {
|
||||
if windowSeconds <= 0 {
|
||||
windowSeconds = 3600
|
||||
func AutoTune(ctx context.Context, req *dto.AutoTuneReq) (res *dto.AutoTuneRes, err error) {
|
||||
if req == nil {
|
||||
return nil, errors.New("request cannot be nil")
|
||||
}
|
||||
if req.WindowSeconds <= 0 {
|
||||
req.WindowSeconds = 3600 // 默认1小时
|
||||
}
|
||||
// 1) 读取模型配置(cap),按 model_name 聚合去重(如果表里有多租户重复数据,取较大上限)
|
||||
var modelRows []*entity.AsynchModel
|
||||
@@ -60,15 +64,15 @@ func AutoTune(ctx context.Context, windowSeconds int) ([]AutoTuneResult, error)
|
||||
if m.MaxConcurrency > cur.MaxConcurrency {
|
||||
cur.MaxConcurrency = m.MaxConcurrency
|
||||
}
|
||||
if m.QueueLimit > cur.QueueLimit {
|
||||
cur.QueueLimit = m.QueueLimit
|
||||
if m.MaxConcurrency*2 > cur.MaxConcurrency*2 {
|
||||
cur.MaxConcurrency = m.MaxConcurrency
|
||||
}
|
||||
if m.ExpectedSeconds > cur.ExpectedSeconds {
|
||||
cur.ExpectedSeconds = m.ExpectedSeconds
|
||||
if m.TimeoutSeconds > cur.TimeoutSeconds {
|
||||
cur.TimeoutSeconds = m.TimeoutSeconds
|
||||
}
|
||||
}
|
||||
if len(modelMap) == 0 {
|
||||
return []AutoTuneResult{}, nil
|
||||
return nil, errors.New("no models found")
|
||||
}
|
||||
|
||||
// 2) 统计指定窗口:按 model_name 计算 cnt 和 P90 执行耗时
|
||||
@@ -89,7 +93,7 @@ SELECT model_name,
|
||||
AND finished_at IS NOT NULL
|
||||
AND finished_at >= (NOW() - (? || ' seconds')::interval)
|
||||
GROUP BY model_name`, public.TableNameTask)
|
||||
r, err := gfdb.DB(ctx).GetAll(ctx, sql, windowSeconds)
|
||||
r, err := gfdb.DB(ctx).GetAll(ctx, sql, req.WindowSeconds)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -108,7 +112,7 @@ SELECT model_name,
|
||||
for modelName, m := range modelMap {
|
||||
s := statMap[modelName]
|
||||
capMax := m.MaxConcurrency
|
||||
capQueue := m.QueueLimit
|
||||
capQueue := m.MaxConcurrency * 2
|
||||
oldMax := GetRuntimeMaxConcurrency(ctx, modelName, capMax)
|
||||
oldQueue := GetRuntimeQueueLimit(ctx, modelName, capQueue)
|
||||
|
||||
@@ -124,7 +128,6 @@ SELECT model_name,
|
||||
CapQueueLimit: capQueue,
|
||||
OldQueueLimit: oldQueue,
|
||||
NewQueueLimit: oldQueue,
|
||||
ExpectedSeconds: m.ExpectedSeconds,
|
||||
})
|
||||
continue
|
||||
}
|
||||
@@ -150,7 +153,7 @@ SELECT model_name,
|
||||
setRuntimeInt(ctx, runtimeMaxConcurrencyKey(modelName), newMax)
|
||||
|
||||
// queue_limit:W_target = expected_seconds * queueFactor
|
||||
exp := m.ExpectedSeconds
|
||||
exp := m.TimeoutSeconds
|
||||
if exp <= 0 {
|
||||
exp = 60
|
||||
}
|
||||
@@ -185,10 +188,11 @@ SELECT model_name,
|
||||
CapQueueLimit: capQueue,
|
||||
OldQueueLimit: oldQueue,
|
||||
NewQueueLimit: newQueue,
|
||||
ExpectedSeconds: m.ExpectedSeconds,
|
||||
})
|
||||
}
|
||||
|
||||
g.Log().Infof(ctx, "[auto_tune] done models=%d windowSeconds=%d", len(out), windowSeconds)
|
||||
return out, nil
|
||||
g.Log().Infof(ctx, "[auto_tune] done models=%d windowSeconds=%d", len(out), req.WindowSeconds)
|
||||
return &dto.AutoTuneRes{
|
||||
List: out,
|
||||
}, nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package service
|
||||
package queue
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -1,4 +1,4 @@
|
||||
package service
|
||||
package queue
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -11,9 +11,9 @@ import (
|
||||
// 上层每小时调用 /model/autoTune 写入运行时值;Worker/CreateTask 读取运行时值生效。
|
||||
|
||||
const (
|
||||
runtimeMaxCKeyPrefix = "asynch:runtime:max_concurrency:" // + model_name
|
||||
runtimeQueueKeyPrefix = "asynch:runtime:queue_limit:" // + model_name
|
||||
runtimeTTLSeconds = 2 * 3600 // 2小时,避免一次调参失败导致立即回退
|
||||
runtimeMaxCKeyPrefix = "asynch:runtime:max_concurrency:" // + model_name
|
||||
runtimeQueueKeyPrefix = "asynch:runtime:queue_limit:" // + model_name
|
||||
runtimeTTLSeconds = 2 * 3600 // 2小时,避免一次调参失败导致立即回退
|
||||
)
|
||||
|
||||
func runtimeMaxConcurrencyKey(modelName string) string {
|
||||
@@ -80,4 +80,3 @@ func clampInt(v, minV, maxV int) int {
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package service
|
||||
package queue
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -34,7 +34,8 @@ end
|
||||
return 1
|
||||
`
|
||||
|
||||
func acquireSemaphore(ctx context.Context, key string, max int, ttlSeconds int64) (bool, error) {
|
||||
// AcquireSemaphore 获取并发令牌
|
||||
func AcquireSemaphore(ctx context.Context, key string, max int, ttlSeconds int64) (bool, error) {
|
||||
if max <= 0 {
|
||||
// 不限制
|
||||
return true, nil
|
||||
@@ -49,8 +50,8 @@ func acquireSemaphore(ctx context.Context, key string, max int, ttlSeconds int64
|
||||
return gconv.Int(r) == 1, nil
|
||||
}
|
||||
|
||||
func releaseSemaphore(ctx context.Context, key string) error {
|
||||
// ReleaseSemaphore 释放并发令牌
|
||||
func ReleaseSemaphore(ctx context.Context, key string) error {
|
||||
_, err := g.Redis().Do(ctx, "EVAL", releaseLua, 1, key)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package service
|
||||
package stat
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -1,18 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"model-gateway/model/entity"
|
||||
)
|
||||
|
||||
// StorageService 结果存储(OSS/MinIO)抽象
|
||||
type StorageService interface {
|
||||
UploadByTask(ctx context.Context, t *entity.AsynchTask, data []byte, fileExt string, contentType string) (ossURL string, err error)
|
||||
}
|
||||
|
||||
// Storage 默认存储实现(优先对接你们的 oss 文件服务;必要时也可以切到 MinIO)
|
||||
var Storage StorageService = &ossStorage{}
|
||||
|
||||
var ErrStorageNotConfigured = errors.New("存储未配置")
|
||||
@@ -1,81 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"mime/multipart"
|
||||
"time"
|
||||
|
||||
"model-gateway/model/entity"
|
||||
|
||||
commonHttp "gitea.com/red-future/common/http"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
"github.com/gogf/gf/v2/util/guid"
|
||||
)
|
||||
|
||||
// ossStorage is the adapter for the OSS file service:
// POST oss/file/uploadFile (multipart/form-data).
type ossStorage struct{}

// uploadFileResponse holds the fields read from the upload endpoint's reply.
type uploadFileResponse struct {
	// FileURL is the URL of the stored file.
	FileURL string `json:"fileURL"`
	// FileSize is the file size in bytes.
	FileSize int `json:"fileSize"`
	// FileName is the file name.
	FileName string `json:"fileName"`
	// FileFormat is the file format.
	FileFormat string `json:"fileFormat"`
	// FileAddressPrefix is the file address prefix.
	FileAddressPrefix string `json:"fileAddressPrefix"`
}
|
||||
|
||||
func (s *ossStorage) UploadByTask(ctx context.Context, _ *entity.AsynchTask, data []byte, fileExt string, _ string) (ossURL string, err error) {
|
||||
// multipart
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
|
||||
ext := fileExt
|
||||
if ext == "" {
|
||||
ext = ".bin"
|
||||
}
|
||||
if ext[0] != '.' {
|
||||
ext = "." + ext
|
||||
}
|
||||
|
||||
filename := fmt.Sprintf("asynch_%d_%s%s", time.Now().Unix(), guid.S(), ext)
|
||||
part, err := writer.CreateFormFile("file", filename)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if _, err := part.Write(data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
contentType := writer.FormDataContentType()
|
||||
if err := writer.Close(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
headers := forwardHeaders(ctx)
|
||||
headers["Content-Type"] = contentType
|
||||
|
||||
fullURL := "oss/file/uploadFile"
|
||||
g.Log().Infof(ctx, "[OSS] upload start url=%s filename=%s size=%d", fullURL, filename, len(data))
|
||||
|
||||
var resp uploadFileResponse
|
||||
if err := commonHttp.Post(ctx, fullURL, headers, &resp, body.Bytes()); err != nil {
|
||||
return "", err
|
||||
}
|
||||
g.Log().Infof(ctx, "[OSS] upload success url=%s size=%d format=%s", resp.FileURL, resp.FileSize, resp.FileFormat)
|
||||
return resp.FileURL, nil
|
||||
}
|
||||
|
||||
// setTaskHeadersToCtx 把任务入库时保存的 header 信息注入 ctx,给 worker 调 OSS 用
|
||||
func setTaskHeadersToCtx(ctx context.Context, headers map[string]string) context.Context {
|
||||
if headers == nil {
|
||||
return ctx
|
||||
}
|
||||
if v := gconv.String(headers["Authorization"]); v != "" {
|
||||
ctx = context.WithValue(ctx, "token", v)
|
||||
}
|
||||
if v := gconv.String(headers["X-User-Info"]); v != "" {
|
||||
ctx = context.WithValue(ctx, "xUserInfo", v)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
305
service/task/task_service.go
Normal file
305
service/task/task_service.go
Normal file
@@ -0,0 +1,305 @@
|
||||
package task
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"model-gateway/common/util"
|
||||
"model-gateway/consts/public"
|
||||
"model-gateway/service/queue"
|
||||
"time"
|
||||
|
||||
"model-gateway/dao"
|
||||
"model-gateway/model/dto"
|
||||
"model-gateway/model/entity"
|
||||
|
||||
"gitea.redpowerfuture.com/red-future/common/beans"
|
||||
"gitea.redpowerfuture.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/os/gtime"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// taskService groups the task operations; it carries no state.
type taskService struct{}

// Task is the package-level taskService instance.
var Task = new(taskService)
|
||||
|
||||
// Create 创建任务
|
||||
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
|
||||
startAt := time.Now()
|
||||
taskID := uuid.NewString()
|
||||
// 1) 检查模型配置,并且获取模型
|
||||
userInfo, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{
|
||||
TenantId: userInfo.TenantId,
|
||||
Creator: userInfo.UserName,
|
||||
},
|
||||
ModelName: req.ModelName,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if model == nil || (model.Enabled != nil && *model.Enabled != 1) {
|
||||
return nil, errors.New("模型不存在或未启用")
|
||||
}
|
||||
|
||||
// 2) 排队上限(严格控制:Redis 原子闸门)
|
||||
limit := queue.GetRuntimeQueueLimit(ctx, req.ModelName, model.MaxConcurrency*2)
|
||||
if limit > 0 {
|
||||
ok, err := queue.AcquireQueueSlot(ctx, req.ModelName, taskID, limit, model.TimeoutSeconds)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !ok {
|
||||
return nil, errors.New("任务排队已满,请稍后再试")
|
||||
}
|
||||
}
|
||||
|
||||
// 3) 插入任务记录
|
||||
if model.CallMode != nil && *model.CallMode == public.CallModeAsync {
|
||||
// 异步调用:注入回调地址后提交,拿到 task_id 轮询
|
||||
req.RequestPayload = util.InjectCallbackURL(ctx, req.RequestPayload, model.CallbackUrl)
|
||||
}
|
||||
storedPayload := map[string]any{
|
||||
"headers": util.ParseHeadMsgHeaders(model.HeadMsg),
|
||||
"body": req.RequestPayload,
|
||||
}
|
||||
_, err = dao.Task.Insert(ctx, &entity.AsynchTask{
|
||||
ModelName: req.ModelName,
|
||||
TaskID: taskID,
|
||||
State: 0,
|
||||
BizName: req.BizName,
|
||||
CallbackURL: req.CallbackUrl,
|
||||
ModelKey: model.ApiKey,
|
||||
InputRef: req.InputRef,
|
||||
RequestPayload: storedPayload,
|
||||
EpicycleId: req.EpicycleId,
|
||||
})
|
||||
if err != nil { // 入库失败:回滚闸门占位
|
||||
queue.ReleaseQueueSlot(ctx, req.ModelName, taskID)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 4) 写操作日志(不影响主流程,失败忽略)
|
||||
ip := ""
|
||||
ua := ""
|
||||
apiPath := "/task/createTask"
|
||||
httpMethod := "POST"
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
ip = utils.GetLocalIP()
|
||||
ua = r.UserAgent()
|
||||
apiPath = r.URL.Path
|
||||
httpMethod = r.Method
|
||||
}
|
||||
_, _ = dao.OpLog.Insert(ctx, &entity.LogsModelOp{
|
||||
IP: ip,
|
||||
UserAgent: ua,
|
||||
APIPath: apiPath,
|
||||
HttpMethod: httpMethod,
|
||||
BizName: req.BizName,
|
||||
ModelName: req.ModelName,
|
||||
TaskID: taskID,
|
||||
OpType: "createTask",
|
||||
Success: 1,
|
||||
ErrorMsg: "",
|
||||
CostMs: time.Since(startAt).Milliseconds(),
|
||||
RequestPayload: storedPayload,
|
||||
ResponsePayload: gdb.Map{
|
||||
"taskId": taskID,
|
||||
},
|
||||
})
|
||||
|
||||
// 5) 获取任务信息
|
||||
task, err := dao.Task.ClaimPendingByTaskIDGlobal(ctx, taskID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if task == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 5) 创建成功后立即异步尝试执行当前任务
|
||||
go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req)
|
||||
|
||||
return &dto.CreateTaskRes{TaskID: taskID}, nil
|
||||
}
|
||||
|
||||
func (s *taskService) ModelTaskCallback(ctx context.Context, req *dto.ModelTaskCallbackReq) (*dto.ModelTaskCallbackRes, error) {
|
||||
g.Log().Infof(ctx, "[模型回调] 收到通知 taskID=%s status=%s", req.TaskID, req.Status)
|
||||
// 1. 查本地任务
|
||||
task, err := dao.Task.Get(ctx, &entity.AsynchTask{
|
||||
TaskID: req.TaskID,
|
||||
})
|
||||
if err != nil || task == nil {
|
||||
return nil, fmt.Errorf("任务不存在: %s", req.TaskID)
|
||||
}
|
||||
|
||||
// 2. 成功:取 video_url 和 usage
|
||||
if req.Status == "succeeded" {
|
||||
result := map[string]any{
|
||||
"video_url": req.Content["video_url"],
|
||||
"usage": req.Usage,
|
||||
}
|
||||
NotifyAsyncResult(req.TaskID, result, nil)
|
||||
return &dto.ModelTaskCallbackRes{Success: true}, nil
|
||||
}
|
||||
|
||||
// 3. 失败/过期
|
||||
if req.Status == "failed" || req.Status == "expired" {
|
||||
NotifyAsyncResult(req.TaskID, nil, fmt.Errorf(req.Status))
|
||||
return &dto.ModelTaskCallbackRes{Success: true}, nil
|
||||
}
|
||||
|
||||
return &dto.ModelTaskCallbackRes{Success: true}, nil
|
||||
}
|
||||
|
||||
// QueryPendingTasks 批量轮询进行中的异步任务
|
||||
func (s *taskService) QueryPendingTasks(ctx context.Context, req *dto.QueryPendingTasksReq) (*dto.QueryPendingTasksRes, error) {
|
||||
limit := req.Limit
|
||||
if limit <= 0 {
|
||||
limit = g.Cfg().MustGet(ctx, "asynch.queryPending.limit", 10).Int()
|
||||
}
|
||||
|
||||
// 1. 查 state=1(执行中)的异步任务
|
||||
tasks, err := dao.Task.GetPendingAsyncTasks(ctx, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. 逐个查询
|
||||
var results []dto.QueryTaskItem
|
||||
for _, t := range tasks {
|
||||
// 拿到模型配置
|
||||
model, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
|
||||
if err != nil || model == nil || model.QueryConfig == nil {
|
||||
continue
|
||||
}
|
||||
result, err := util.PullTaskResult(ctx, nil, model.QueryConfig, model.HeadMsg)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[轮询] 查询失败 taskID=%s err=%v", t.TaskID, err)
|
||||
continue
|
||||
}
|
||||
|
||||
status := gconv.String(result["status"])
|
||||
item := dto.QueryTaskItem{
|
||||
TaskID: t.TaskID,
|
||||
Status: status,
|
||||
Content: result["content"].(map[string]any),
|
||||
Usage: result["usage"].(map[string]any),
|
||||
}
|
||||
results = append(results, item)
|
||||
|
||||
// 如果任务完成,通知等待通道
|
||||
if status == "succeeded" || status == "failed" || status == "expired" {
|
||||
NotifyAsyncResult(t.TaskID, result["content"].(map[string]any), nil)
|
||||
}
|
||||
}
|
||||
|
||||
return &dto.QueryPendingTasksRes{
|
||||
Total: len(results),
|
||||
Results: results,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetResult 获取任务结果
|
||||
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
|
||||
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
|
||||
TaskID: taskID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if t == nil {
|
||||
return nil, errors.New("任务不存在")
|
||||
}
|
||||
return &dto.GetTaskResultRes{
|
||||
OssFile: t.OssFile,
|
||||
State: t.State,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetBatch 批量查询任务;将成功(state=2)的任务更新为已下载(state=4),并写入过期时间
|
||||
func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
|
||||
if req == nil || len(req.TaskIDs) == 0 {
|
||||
return &dto.GetTaskBatchRes{List: []dto.GetTaskBatchItem{}}, nil
|
||||
}
|
||||
// 1) 先查当前租户下的任务列表
|
||||
list, err := dao.Task.ListByTaskIDs(ctx, req.TaskIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2) 对成功(state=2)的任务:标记为已下载(state=4)并写入 expire_at
|
||||
now := time.Now()
|
||||
for _, t := range list {
|
||||
if t == nil {
|
||||
continue
|
||||
}
|
||||
if t.State != 2 {
|
||||
continue
|
||||
}
|
||||
// 按模型配置决定保留时间
|
||||
m, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
ModelName: t.ModelName,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
retainSeconds := 86400
|
||||
if m != nil && m.AutoCleanSeconds > 0 {
|
||||
retainSeconds = m.AutoCleanSeconds
|
||||
}
|
||||
expireAt := gtime.New(now.Add(time.Duration(retainSeconds) * time.Second))
|
||||
_ = dao.Task.MarkDownloadedByID(ctx, t.Id, expireAt)
|
||||
|
||||
// 为了本次返回一致性,内存里也更新
|
||||
t.State = 4
|
||||
t.ExpireAt = expireAt
|
||||
}
|
||||
|
||||
// 3) 组装返回
|
||||
items := make([]dto.GetTaskBatchItem, 0, len(list))
|
||||
for _, t := range list {
|
||||
if t == nil {
|
||||
continue
|
||||
}
|
||||
items = append(items, dto.GetTaskBatchItem{
|
||||
TaskID: t.TaskID,
|
||||
State: t.State,
|
||||
OssFile: t.OssFile,
|
||||
})
|
||||
}
|
||||
return &dto.GetTaskBatchRes{List: items}, nil
|
||||
}
|
||||
|
||||
// List 获取任务列表
|
||||
func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
|
||||
pageNum, pageSize := 1, 10
|
||||
if req != nil {
|
||||
if req.PageNum > 0 {
|
||||
pageNum = req.PageNum
|
||||
}
|
||||
if req.PageSize > 0 {
|
||||
pageSize = req.PageSize
|
||||
}
|
||||
}
|
||||
modelName := ""
|
||||
taskID := ""
|
||||
var state *int
|
||||
if req != nil {
|
||||
modelName = req.ModelName
|
||||
taskID = req.TaskID
|
||||
state = req.State
|
||||
}
|
||||
list, total, err := dao.Task.List(ctx, pageNum, pageSize, modelName, taskID, state)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.ListTaskRes{List: list, Total: total}, nil
|
||||
}
|
||||
494
service/task/worker.go
Normal file
494
service/task/worker.go
Normal file
@@ -0,0 +1,494 @@
|
||||
package task
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"model-gateway/common/util"
|
||||
"model-gateway/consts/public"
|
||||
"model-gateway/dao"
|
||||
"model-gateway/model/dto"
|
||||
"model-gateway/model/entity"
|
||||
"model-gateway/service/gateway"
|
||||
"model-gateway/service/queue"
|
||||
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// asyncWorker executes tasks; it carries no state.
type asyncWorker struct{}

// AsyncWorker is the package-level asyncWorker instance.
var AsyncWorker = new(asyncWorker)
|
||||
|
||||
// handleOne 执行一次完整的任务
|
||||
func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, req *dto.CreateTaskReq) {
|
||||
body := util.GetModelBody(task.RequestPayload) // 核心请求参数
|
||||
maxRetry := model.RetryTimes // 重试次数
|
||||
startTime := time.Now()
|
||||
|
||||
g.Log().Infof(ctx, "[执行任务][开始] taskId=%s model=%s", task.TaskID, task.ModelName)
|
||||
|
||||
// 1) 分布式并发控制
|
||||
semKey := fmt.Sprintf("asynch:sem:%s", task.ModelName)
|
||||
maxC := queue.GetRuntimeMaxConcurrency(ctx, task.ModelName, model.MaxConcurrency)
|
||||
acquired, err := queue.AcquireSemaphore(ctx, semKey, maxC, 3600)
|
||||
if err != nil {
|
||||
task.DurationSeconds = int64(time.Since(startTime).Seconds())
|
||||
w.failTask(ctx, task, startTime, err.Error())
|
||||
return
|
||||
}
|
||||
if !acquired {
|
||||
g.Log().Infof(ctx, "[执行任务][排队] 并发已满,放回队列 taskId=%s", task.TaskID)
|
||||
_ = w.rollbackToPending(ctx, task.Id)
|
||||
return
|
||||
}
|
||||
defer func() { _ = queue.ReleaseSemaphore(ctx, semKey) }()
|
||||
|
||||
// 2) 调用模型
|
||||
switch {
|
||||
case model.CallMode != nil && *model.CallMode == public.CallModeStream:
|
||||
rawBytes, err := w.callModelStream(ctx, task, model, body)
|
||||
if err != nil {
|
||||
w.failTask(ctx, task, startTime, err.Error())
|
||||
return
|
||||
}
|
||||
body, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
|
||||
if err != nil {
|
||||
w.failTask(ctx, task, startTime, err.Error())
|
||||
return
|
||||
}
|
||||
case model.CallMode != nil && *model.CallMode == public.CallModeAsync:
|
||||
body, err = w.callModel(ctx, task, model, body)
|
||||
if err != nil {
|
||||
w.failTask(ctx, task, startTime, err.Error())
|
||||
return
|
||||
}
|
||||
body, err = util.PullTaskResult(ctx, body, model.QueryConfig, model.HeadMsg)
|
||||
if err != nil {
|
||||
w.failTask(ctx, task, startTime, err.Error())
|
||||
return
|
||||
}
|
||||
default:
|
||||
body, err = w.callModel(ctx, task, model, body)
|
||||
if err != nil {
|
||||
w.failTask(ctx, task, startTime, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 3) 保存临时文件
|
||||
tmpPath, err := util.SaveTempFileByType(task.TaskID, body, task.TmpFile)
|
||||
if err == nil && tmpPath != "" {
|
||||
task.TmpFile = tmpPath
|
||||
task.Phase = 1
|
||||
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath)
|
||||
}
|
||||
|
||||
// 4) 解析校验 + 响应映射(可重试,失败重新调模型)
|
||||
body, err = w.parseAndRetry(ctx, body, task, model, req, maxRetry, startTime)
|
||||
if err != nil {
|
||||
task.TextResult = body
|
||||
w.failTask(ctx, task, startTime, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 5) 上传 OSS(可重试)
|
||||
var oss *gateway.UploadFileResponse
|
||||
for attempt := 0; attempt <= maxRetry; attempt++ {
|
||||
if attempt > 0 {
|
||||
g.Log().Infof(ctx, "[执行任务][重试] OSS上传 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
|
||||
}
|
||||
oss, err = w.uploadOSS(ctx, task)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v",
|
||||
task.TaskID, attempt, maxRetry, err)
|
||||
if attempt == maxRetry {
|
||||
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, task.Id, err.Error())
|
||||
w.failTask(ctx, task, startTime, fmt.Sprintf("OSS上传重试耗尽: %v", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 6) 成功回调
|
||||
task.State = 2
|
||||
task.DurationSeconds = int64(time.Since(startTime).Seconds())
|
||||
task.OssFile = oss.FileAddressPrefix + oss.FileURL
|
||||
task.FileType = oss.FileFormat
|
||||
task.TextResult = body
|
||||
task.FileSize = int64(oss.FileSize)
|
||||
|
||||
if err = dao.Task.UpdateSuccessGlobal(ctx, task); err != nil {
|
||||
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err)
|
||||
return
|
||||
}
|
||||
|
||||
queue.ReleaseQueueSlot(ctx, task.ModelName, task.TaskID)
|
||||
go gateway.TriggerCallback(context.WithoutCancel(ctx), task)
|
||||
if req.EpicycleId != 0 {
|
||||
go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), task, req.EpicycleId)
|
||||
}
|
||||
|
||||
g.Log().Infof(ctx, "[执行任务][成功] taskId=%s duration=%ds fileType=%s textLen=%d callbackUrl=%s",
|
||||
task.TaskID, task.DurationSeconds, oss.FileFormat, len(body), task.CallbackURL)
|
||||
|
||||
// 7) 删除临时文件
|
||||
_ = os.Remove(task.TmpFile)
|
||||
}
|
||||
|
||||
// callModelStream 调用模型,返回原始字节(不做响应映射,用于流式输出)
|
||||
func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) ([]byte, error) {
|
||||
var data []byte
|
||||
var err error
|
||||
|
||||
if task.Phase == 1 && strings.TrimSpace(task.TmpFile) != "" {
|
||||
data, err = os.ReadFile(task.TmpFile)
|
||||
if err != nil || len(data) == 0 {
|
||||
data = nil
|
||||
}
|
||||
}
|
||||
|
||||
if data == nil {
|
||||
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName)
|
||||
data, err = InvokeModel(ctx, model, body, task.ModelKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tmpPath, tmpErr := util.SaveTmpResult(task.TaskID, data, "")
|
||||
if tmpErr == nil && tmpPath != "" {
|
||||
task.TmpFile = tmpPath
|
||||
task.Phase = 1
|
||||
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath)
|
||||
}
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// asyncResult is the outcome delivered to a goroutine waiting on an
// asynchronous model task.
type asyncResult struct {
	result map[string]any // result payload supplied by the notifier
	err    error          // error supplied by the notifier, if any
}

// asyncTaskChan maps an upstream task ID to the chan asyncResult its waiter
// is blocked on.
var asyncTaskChan sync.Map
|
||||
|
||||
func (w *asyncWorker) callModelAsync(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) (map[string]any, error) {
|
||||
// 1. 提交异步任务
|
||||
body, err := w.callModel(ctx, task, model, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 2. 拿到 task_id
|
||||
taskID := gjson.New(body).Get(model.ResponseBody).String()
|
||||
|
||||
// 3. 创建等待通道
|
||||
ch := make(chan asyncResult, 1)
|
||||
asyncTaskChan.Store(taskID, ch)
|
||||
defer func() {
|
||||
asyncTaskChan.Delete(taskID)
|
||||
close(ch)
|
||||
}()
|
||||
|
||||
// 4. 阻塞等待回调或超时
|
||||
timeout := time.Duration(model.TimeoutSeconds) * time.Second
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
g.Log().Infof(ctx, "[异步任务] 开始等待结果 taskID=%s timeout=%v", taskID, timeout)
|
||||
|
||||
select {
|
||||
case res, ok := <-ch:
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("异步任务通道已关闭: taskID=%s", taskID)
|
||||
}
|
||||
g.Log().Infof(ctx, "[异步任务] 获取结果成功 taskID=%s", taskID)
|
||||
return res.result, res.err
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("异步任务超时: taskID=%s", taskID)
|
||||
}
|
||||
}
|
||||
|
||||
// NotifyAsyncResult 回调接口调用此方法通知结果
|
||||
func NotifyAsyncResult(taskID string, result map[string]any, err error) {
|
||||
if ch, ok := asyncTaskChan.Load(taskID); ok {
|
||||
ch.(chan asyncResult) <- asyncResult{result: result, err: err}
|
||||
}
|
||||
}
|
||||
|
||||
// callModel 调用模型 + 检测文件类型 + 保存临时文件
|
||||
// 返回: 解析后的响应体, error
|
||||
func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) (map[string]any, error) {
|
||||
var data []byte
|
||||
var err error
|
||||
|
||||
// 1) 如果已有临时文件且 phase=1,直接读取
|
||||
if task.Phase == 1 && strings.TrimSpace(task.TmpFile) != "" {
|
||||
data, err = os.ReadFile(task.TmpFile)
|
||||
if err != nil || len(data) == 0 {
|
||||
g.Log().Warningf(ctx, "[callModel] 读取临时文件失败,重新调用模型 taskId=%s err=%v", task.TaskID, err)
|
||||
data = nil
|
||||
}
|
||||
}
|
||||
|
||||
// 2) 没有可用数据,调用模型
|
||||
if data == nil {
|
||||
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName)
|
||||
data, err = InvokeModel(ctx, model, body, task.ModelKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 3) 检测文件类型,保存临时文件
|
||||
_, ext := util.DetectFileType(data)
|
||||
tmpPath, tmpErr := util.SaveTmpResult(task.TaskID, data, ext)
|
||||
if tmpErr == nil && tmpPath != "" {
|
||||
task.TmpFile = tmpPath
|
||||
task.Phase = 1
|
||||
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath)
|
||||
}
|
||||
}
|
||||
|
||||
// 4) 检测文件类型,提取文本结果
|
||||
contentType, _ := util.DetectFileType(data)
|
||||
var textResult string
|
||||
if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") {
|
||||
textResult = string(data)
|
||||
}
|
||||
|
||||
// 5) 非文本内容,返回错误
|
||||
if textResult == "" {
|
||||
return nil, fmt.Errorf("模型返回非文本内容,contentType=%s", contentType)
|
||||
}
|
||||
|
||||
// 6) 解析并返回
|
||||
return gjson.New(textResult).Map(), nil
|
||||
}
|
||||
|
||||
// parseAndRetry 解析模型返回结果,并重试
|
||||
func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.AsynchTask, model *entity.AsynchModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) {
|
||||
for attempt := 0; attempt <= maxRetry; attempt++ {
|
||||
if attempt > 0 {
|
||||
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
|
||||
}
|
||||
|
||||
// 1) 响应映射
|
||||
mapped, err := util.MapResponsePayload(model.ResponseMapping, body)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[执行任务][映射失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
|
||||
if attempt == maxRetry {
|
||||
return nil, fmt.Errorf("响应映射重试耗尽: %w", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 2) 先存 token 到数据库,防止后续失败丢失
|
||||
if tokens, ok := mapped[model.ResponseTokenField]; ok {
|
||||
task.ExpendTokens = gconv.Int64(tokens)
|
||||
_ = dao.Task.UpdateColumns(ctx, task.Id, entity.AsynchTask{
|
||||
ExpendTokens: gconv.Int64(body[model.ResponseTokenField]),
|
||||
})
|
||||
}
|
||||
|
||||
// 3) 解析 + 校验
|
||||
var parsed map[string]any
|
||||
switch req.BuildType {
|
||||
case public.BuildTypePrompt, public.BuildTypeNode:
|
||||
parsed, err = util.ParseAndValidate(mapped, model)
|
||||
if err == nil {
|
||||
return parsed, nil
|
||||
}
|
||||
case public.BuildTypeStruct:
|
||||
parsed = util.ParseStructResult(mapped, model.ResponseBody)
|
||||
return parsed, nil
|
||||
default:
|
||||
return mapped, nil
|
||||
}
|
||||
|
||||
g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
|
||||
|
||||
if attempt == maxRetry {
|
||||
return nil, fmt.Errorf("JSON解析重试耗尽: %w", err)
|
||||
}
|
||||
|
||||
// 4) 重新调模型(直接调,不走缓存)
|
||||
_ = dao.Task.IncRetryCountGlobal(ctx, task.Id)
|
||||
reqBody := util.GetModelBody(task.RequestPayload)
|
||||
rawData, callErr := InvokeModel(ctx, model, reqBody, task.ModelKey)
|
||||
if callErr != nil {
|
||||
g.Log().Warningf(ctx, "[执行任务][重调模型失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, callErr)
|
||||
continue
|
||||
}
|
||||
|
||||
// 5) 解析原始响应,覆盖 body 进入下一轮
|
||||
var rawResp map[string]any
|
||||
if err := json.Unmarshal(rawData, &rawResp); err != nil {
|
||||
g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err)
|
||||
continue
|
||||
}
|
||||
body = rawResp
|
||||
}
|
||||
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// InvokeModel 调用模型服务,返回二进制结果
|
||||
// modelKey 用于覆盖/补充模型配置 head_msg(例如每次请求携带不同的 X-API-Key)
|
||||
func InvokeModel(ctx context.Context, model *entity.AsynchModel, body map[string]any, modelKey string) ([]byte, error) {
|
||||
// 1)请求参数映射:将标准 payload 按模型配置的 requestMapping 转为模型需要的格式
|
||||
//—— 请求映射实际处理为提示词构建请求,因为有附加字段及其他字段的拼接。这里不方便做请求映射
|
||||
//mappedPayload := util.ReverseMap(model.RequestMapping, payload)
|
||||
|
||||
// 2)构建请求 URL 和超时
|
||||
baseURL := strings.TrimRight(model.BaseURL, "/")
|
||||
timeout := time.Duration(model.TimeoutSeconds) * time.Second
|
||||
client := &http.Client{Timeout: timeout}
|
||||
method := strings.ToUpper(strings.TrimSpace(model.HttpMethod))
|
||||
|
||||
// 3)构建 HTTP 请求
|
||||
var req *http.Request
|
||||
switch method {
|
||||
case http.MethodGet:
|
||||
q, err := util.BodyToQuery(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(q) > 0 {
|
||||
if strings.Contains(baseURL, "?") {
|
||||
baseURL = baseURL + "&" + q.Encode()
|
||||
} else {
|
||||
baseURL = baseURL + "?" + q.Encode()
|
||||
}
|
||||
}
|
||||
req, err = http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil)
|
||||
default:
|
||||
bodyBytes, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes))
|
||||
}
|
||||
|
||||
// 4)注入请求头:先模型静态配置,再动态 modelKey(后者可覆盖前者)
|
||||
for hk, hv := range util.ParseHeadMsgHeaders(model.HeadMsg) {
|
||||
req.Header.Set(hk, hv)
|
||||
}
|
||||
if modelKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+modelKey)
|
||||
}
|
||||
if method != http.MethodGet {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
// 5)发送请求
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 6)读取响应体
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 7)检查 HTTP 状态码
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
msg := string(b)
|
||||
return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// // InvokeModel 调用模型服务,返回二进制结果
|
||||
//
|
||||
// func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelKey string) ([]byte, error) {
|
||||
// if m == nil || m.BaseURL == "" {
|
||||
// return nil, fmt.Errorf("模型配置不完整")
|
||||
// }
|
||||
// // 请求参数映射
|
||||
// mappedPayload, err := mapRequestPayload(m.RequestMapping, payload)
|
||||
// if err != nil {
|
||||
// return nil, fmt.Errorf("请求参数映射失败: %w", err)
|
||||
// }
|
||||
// // 合并请求头
|
||||
// headers := util.ForwardHeaders(ctx)
|
||||
// for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) {
|
||||
// headers[hk] = hv
|
||||
// }
|
||||
// for hk, hv := range parseHeadMsgHeaders(modelKey) {
|
||||
// headers[hk] = hv
|
||||
// }
|
||||
//
|
||||
// // 设置超时
|
||||
// timeout := time.Duration(m.TimeoutSeconds) * time.Second
|
||||
// if timeout <= 0 {
|
||||
// timeout = 600 * time.Second
|
||||
// }
|
||||
// ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
// defer cancel()
|
||||
//
|
||||
// invokeUrl := strings.TrimRight(m.BaseURL, "/")
|
||||
// method := strings.ToUpper(strings.TrimSpace(m.HttpMethod))
|
||||
// if method == "" {
|
||||
// method = http.MethodPost
|
||||
// }
|
||||
//
|
||||
// var respBytes []byte
|
||||
//
|
||||
// switch method {
|
||||
// case http.MethodGet:
|
||||
// err = commonHttp.Get(ctx, invokeUrl, headers, &respBytes, mappedPayload)
|
||||
// default:
|
||||
// err = commonHttp.Post(ctx, invokeUrl, headers, &respBytes, mappedPayload)
|
||||
// }
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// // 响应参数映射
|
||||
// mappedResponse, err := mapResponsePayload(m.ResponseMapping, respBytes)
|
||||
// if err != nil {
|
||||
// g.Log().Warningf(ctx, "响应参数映射失败: %v,返回原始数据", err)
|
||||
// return respBytes, nil
|
||||
// }
|
||||
// return mappedResponse, nil
|
||||
// }
|
||||
|
||||
// uploadOSS 从临时文件上传 OSS
|
||||
func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.AsynchTask) (*gateway.UploadFileResponse, error) {
|
||||
data, err := os.ReadFile(t.TmpFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取临时文件失败: %w", err)
|
||||
}
|
||||
_, ext := util.DetectFileType(data)
|
||||
return gateway.UploadByTask(ctx, data, ext)
|
||||
}
|
||||
|
||||
// failTask 任务失败统一处理:更新数据库 + 释放排队 + 回调
|
||||
func (w *asyncWorker) failTask(ctx context.Context, t *entity.AsynchTask, startTime time.Time, errMsg string) {
|
||||
t.State = 3
|
||||
t.ErrorMsg = errMsg
|
||||
t.DurationSeconds = int64(time.Since(startTime).Seconds())
|
||||
_ = dao.Task.UpdateFailedGlobal(ctx, t)
|
||||
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||||
}
|
||||
|
||||
// rollbackToPending 恢复任务状态为 PENDING
|
||||
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
|
||||
return dao.Task.RollbackToPendingGlobal(ctx, id)
|
||||
}
|
||||
@@ -1,266 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"model-gateway/dao"
|
||||
"model-gateway/model/dto"
|
||||
"model-gateway/model/entity"
|
||||
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/os/gtime"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// taskService groups the task-related service methods; it has no fields.
type taskService struct{}

// Task is the shared package-level task service instance.
var Task = new(taskService)
|
||||
|
||||
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
|
||||
fmt.Printf("打印请求:%+v", req)
|
||||
startAt := time.Now()
|
||||
// 固化 token/user 等信息
|
||||
ctx = asyncCtx(ctx)
|
||||
|
||||
// 1) 检查模型配置
|
||||
m, err := dao.Model.GetByModelName(ctx, req.ModelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if m == nil || (m.Enabled != nil && *m.Enabled != 1) {
|
||||
return nil, errors.New("模型不存在或未启用")
|
||||
}
|
||||
|
||||
taskID := uuid.NewString()
|
||||
// 2) 排队上限(严格控制:Redis 原子闸门)
|
||||
limit := GetRuntimeQueueLimit(ctx, req.ModelName, m.QueueLimit)
|
||||
if limit > 0 {
|
||||
ok, err := AcquireQueueSlot(ctx, req.ModelName, taskID, limit, m.ExpectedSeconds)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !ok {
|
||||
return nil, errors.New("任务排队已满,请稍后再试")
|
||||
}
|
||||
}
|
||||
|
||||
// 将调用模型的 payload 与透传头信息一起存入 request_payload,供后台 worker 使用
|
||||
storedPayload := map[string]any{
|
||||
"payload": req.RequestPayload,
|
||||
"headers": forwardHeaders(ctx),
|
||||
}
|
||||
|
||||
t := &entity.AsynchTask{
|
||||
ModelName: req.ModelName,
|
||||
TaskID: taskID,
|
||||
State: 0,
|
||||
BizName: req.BizName,
|
||||
CallbackURL: req.CallbackUrl,
|
||||
ModelKey: m.ApiKey,
|
||||
InputRef: req.InputRef,
|
||||
RequestPayload: storedPayload,
|
||||
EpicycleId: req.EpicycleId,
|
||||
}
|
||||
_, err = dao.Task.Insert(ctx, t)
|
||||
if err != nil {
|
||||
// 入库失败:回滚闸门占位
|
||||
ReleaseQueueSlot(ctx, req.ModelName, taskID)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 3) 写操作日志(尽量不影响主流程,失败忽略)
|
||||
ip := ""
|
||||
ua := ""
|
||||
apiPath := "/task/createTask"
|
||||
httpMethod := "POST"
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
ip = r.GetClientIp()
|
||||
ua = r.UserAgent()
|
||||
apiPath = r.URL.Path
|
||||
httpMethod = r.Method
|
||||
}
|
||||
_, _ = dao.OpLog.Insert(ctx, &entity.LogsModelOp{
|
||||
IP: ip,
|
||||
UserAgent: ua,
|
||||
APIPath: apiPath,
|
||||
HttpMethod: httpMethod,
|
||||
BizName: req.BizName,
|
||||
ModelName: req.ModelName,
|
||||
TaskID: taskID,
|
||||
OpType: "createTask",
|
||||
Success: 1,
|
||||
ErrorMsg: "",
|
||||
CostMs: time.Since(startAt).Milliseconds(),
|
||||
RequestPayload: storedPayload,
|
||||
ResponsePayload: gdb.Map{
|
||||
"taskId": taskID,
|
||||
},
|
||||
})
|
||||
|
||||
// 4) 创建成功后立即异步尝试执行当前任务,并仅在任务仍处于 pending(state=0) 时做定向轮询。
|
||||
// 一旦任务进入 running/success/failed/downloaded,就停止轮询,避免一直空转。
|
||||
go s.pollAndRunUntilPicked(context.WithoutCancel(ctx), taskID, req.EpicycleId)
|
||||
|
||||
return &dto.CreateTaskRes{TaskID: taskID}, nil
|
||||
}
|
||||
|
||||
// pollAndRunUntilPicked 用于 createTask 创建后的“轻量级定向轮询”:
|
||||
// - 目标:尽快把刚创建的任务拉起来执行
|
||||
// - 只在任务仍为 pending(state=0) 时继续尝试抢占
|
||||
// - 一旦任务进入 running(1) / success(2) / failed(3) / downloaded(4),立即停止
|
||||
// - 这样不会无限轮询;runWork 仍负责处理积压队列和未处理到的任务
|
||||
func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string, epicycleId int64) {
|
||||
if taskID == "" {
|
||||
return
|
||||
}
|
||||
interval := g.Cfg().MustGet(ctx, "asynch.worker.intervalSeconds").Int()
|
||||
if interval <= 0 {
|
||||
interval = 5
|
||||
}
|
||||
g.Log().Infof(ctx, "[task-auto-run][start] taskId=%s interval=%ds", taskID, interval)
|
||||
|
||||
ticker := time.NewTicker(time.Duration(interval) * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
tryRun := func() bool {
|
||||
t, err := dao.Task.GetByTaskID(ctx, taskID)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[task-auto-run][stop] taskId=%s reason=query_failed err=%v", taskID, err)
|
||||
return true
|
||||
}
|
||||
if t == nil {
|
||||
g.Log().Warningf(ctx, "[task-auto-run][stop] taskId=%s reason=task_not_found", taskID)
|
||||
return true
|
||||
}
|
||||
switch t.State {
|
||||
case 0:
|
||||
if err := AsyncWorker.RunByTaskID(ctx, taskID, epicycleId); err != nil {
|
||||
g.Log().Warningf(ctx, "[task-auto-run][retry] taskId=%s state=0 err=%v", taskID, err)
|
||||
} else {
|
||||
g.Log().Infof(ctx, "[task-auto-run][triggered] taskId=%s state=0", taskID)
|
||||
}
|
||||
return false
|
||||
case 1:
|
||||
g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=running", taskID)
|
||||
return true
|
||||
case 2, 3, 4:
|
||||
g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=terminal state=%d", taskID, t.State)
|
||||
return true
|
||||
default:
|
||||
g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=unknown_state state=%d", taskID, t.State)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// 先立即尝试一次
|
||||
if stop := tryRun(); stop {
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=context_done", taskID)
|
||||
return
|
||||
case <-ticker.C:
|
||||
if stop := tryRun(); stop {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
|
||||
t, err := dao.Task.GetByTaskID(ctx, taskID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if t == nil {
|
||||
return nil, errors.New("任务不存在")
|
||||
}
|
||||
return &dto.GetTaskResultRes{
|
||||
OssFile: t.OssFile,
|
||||
State: t.State,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetBatch 批量查询任务;将成功(state=2)的任务更新为已下载(state=4),并写入过期时间
|
||||
func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
|
||||
if req == nil || len(req.TaskIDs) == 0 {
|
||||
return &dto.GetTaskBatchRes{List: []dto.GetTaskBatchItem{}}, nil
|
||||
}
|
||||
// 1) 先查当前租户下的任务列表
|
||||
list, err := dao.Task.ListByTaskIDs(ctx, req.TaskIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2) 对成功(state=2)的任务:标记为已下载(state=4)并写入 expire_at
|
||||
now := time.Now()
|
||||
for _, t := range list {
|
||||
if t == nil {
|
||||
continue
|
||||
}
|
||||
if t.State != 2 {
|
||||
continue
|
||||
}
|
||||
// 按模型配置决定保留时间
|
||||
m, err := dao.Model.GetByModelName(ctx, t.ModelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
retainSeconds := 86400
|
||||
if m != nil && m.AutoCleanSeconds > 0 {
|
||||
retainSeconds = m.AutoCleanSeconds
|
||||
}
|
||||
expireAt := gtime.New(now.Add(time.Duration(retainSeconds) * time.Second))
|
||||
_ = dao.Task.MarkDownloadedByID(ctx, t.Id, expireAt)
|
||||
|
||||
// 为了本次返回一致性,内存里也更新
|
||||
t.State = 4
|
||||
t.ExpireAt = expireAt
|
||||
}
|
||||
|
||||
// 3) 组装返回
|
||||
items := make([]dto.GetTaskBatchItem, 0, len(list))
|
||||
for _, t := range list {
|
||||
if t == nil {
|
||||
continue
|
||||
}
|
||||
items = append(items, dto.GetTaskBatchItem{
|
||||
TaskID: t.TaskID,
|
||||
State: t.State,
|
||||
OssFile: t.OssFile,
|
||||
})
|
||||
}
|
||||
return &dto.GetTaskBatchRes{List: items}, nil
|
||||
}
|
||||
|
||||
func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
|
||||
pageNum, pageSize := 1, 10
|
||||
if req != nil {
|
||||
if req.PageNum > 0 {
|
||||
pageNum = req.PageNum
|
||||
}
|
||||
if req.PageSize > 0 {
|
||||
pageSize = req.PageSize
|
||||
}
|
||||
}
|
||||
modelName := ""
|
||||
taskID := ""
|
||||
var state *int
|
||||
if req != nil {
|
||||
modelName = req.ModelName
|
||||
taskID = req.TaskID
|
||||
state = req.State
|
||||
}
|
||||
list, total, err := dao.Task.List(ctx, pageNum, pageSize, modelName, taskID, state)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.ListTaskRes{List: list, Total: total}, nil
|
||||
}
|
||||
@@ -1,38 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// saveTmpResult writes model output to <os.TempDir()>/model-asynch/<taskID><ext>
// and returns the file path, so that a later retry can upload to OSS without
// calling the model again.
//
// An empty ext becomes ".bin"; an ext without a leading dot gets one. The
// directory is created when missing.
func saveTmpResult(taskID string, data []byte, ext string) (string, error) {
	baseDir := filepath.Join(os.TempDir(), "model-asynch")
	if mkErr := os.MkdirAll(baseDir, 0o755); mkErr != nil {
		return "", mkErr
	}

	suffix := ext
	switch {
	case suffix == "":
		suffix = ".bin"
	case suffix[0] != '.':
		suffix = "." + suffix
	}

	target := filepath.Join(baseDir, fmt.Sprintf("%s%s", taskID, suffix))
	if writeErr := os.WriteFile(target, data, 0o644); writeErr != nil {
		return "", writeErr
	}
	return target, nil
}
|
||||
|
||||
// loadTmpResult returns the full content of the temporary result file at path.
func loadTmpResult(path string) ([]byte, error) {
	content, err := os.ReadFile(path)
	return content, err
}
|
||||
|
||||
// deleteTmpResult removes the temporary result file at path. An empty path is
// a no-op and removal errors are ignored.
func deleteTmpResult(path string) {
	if path != "" {
		_ = os.Remove(path)
	}
}
|
||||
|
||||
113
service/utils.go
113
service/utils.go
@@ -1,113 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/gogf/gf/v2/container/gvar"
|
||||
)
|
||||
|
||||
func normalizeFormValue(v any) any {
|
||||
// 目标:对外永远返回 JSON 数组/对象,而不是字符串。
|
||||
if v == nil {
|
||||
return []any{}
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case string:
|
||||
s := strings.TrimSpace(t)
|
||||
if s == "" {
|
||||
return []any{}
|
||||
}
|
||||
return normalizeFormValueFromJSONString(s)
|
||||
case []byte:
|
||||
if len(t) == 0 {
|
||||
return []any{}
|
||||
}
|
||||
return normalizeFormValueFromJSONBytes(t)
|
||||
case *gvar.Var:
|
||||
// goframe 常见的 DB 返回类型
|
||||
if t == nil {
|
||||
return []any{}
|
||||
}
|
||||
b := t.Bytes()
|
||||
if len(b) > 0 {
|
||||
return normalizeFormValueFromJSONBytes(b)
|
||||
}
|
||||
s := strings.TrimSpace(t.String())
|
||||
if s == "" {
|
||||
return []any{}
|
||||
}
|
||||
return normalizeFormValueFromJSONString(s)
|
||||
default:
|
||||
// 尝试兼容其他“像 JSON 的值类型”(例如实现了 Bytes/String 的包装类型)
|
||||
if vb, ok := v.(interface{ Bytes() []byte }); ok {
|
||||
if b := vb.Bytes(); len(b) > 0 {
|
||||
return normalizeFormValueFromJSONBytes(b)
|
||||
}
|
||||
}
|
||||
if vs, ok := v.(interface{ String() string }); ok {
|
||||
if s := strings.TrimSpace(vs.String()); s != "" {
|
||||
return normalizeFormValueFromJSONString(s)
|
||||
}
|
||||
}
|
||||
// 已经是 []any / map[string]any 等结构
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeFormValueFromJSONString decodes s as JSON, tolerating legacy data
// where a JSON array/object was itself stored as a JSON string, for example
// form_json = '"[]"' or '"[{...}]"'.
//
// Invalid JSON and null yield an empty []any. If the decoded value is a string,
// it is decoded once more when it starts with '[' or '{'; any other string
// (including a blank one) also yields an empty []any.
func normalizeFormValueFromJSONString(s string) any {
	var decoded any
	if err := json.Unmarshal([]byte(s), &decoded); err != nil || decoded == nil {
		return []any{}
	}

	nested, isString := decoded.(string)
	if !isString {
		return decoded
	}

	// The outer layer was a JSON string: unwrap one more level if it looks
	// like an array or object.
	nested = strings.TrimSpace(nested)
	if strings.HasPrefix(nested, "[") || strings.HasPrefix(nested, "{") {
		var inner any
		if err := json.Unmarshal([]byte(nested), &inner); err == nil && inner != nil {
			return inner
		}
	}
	return []any{}
}
|
||||
|
||||
func normalizeFormValueFromJSONBytes(b []byte) any {
|
||||
var out any
|
||||
if err := json.Unmarshal(b, &out); err != nil || out == nil {
|
||||
return []any{}
|
||||
}
|
||||
// bytes 解出来也可能是 string(同上)
|
||||
if inner, ok := out.(string); ok {
|
||||
return normalizeFormValueFromJSONString(inner)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func ParseJSONField(field any) any {
|
||||
var v *gvar.Var
|
||||
switch val := field.(type) {
|
||||
case *gvar.Var:
|
||||
v = val
|
||||
default:
|
||||
return field
|
||||
}
|
||||
|
||||
if v == nil || v.IsNil() || v.IsEmpty() {
|
||||
return nil
|
||||
}
|
||||
|
||||
str := v.String()
|
||||
var result any
|
||||
if json.Unmarshal([]byte(str), &result) == nil {
|
||||
return result
|
||||
}
|
||||
return str
|
||||
}
|
||||
@@ -1,246 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"model-gateway/dao"
|
||||
"model-gateway/model/entity"
|
||||
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/os/grpool"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// asyncWorker groups the task-execution methods; it has no fields.
type asyncWorker struct{}

// AsyncWorker is the shared package-level worker instance.
var AsyncWorker = new(asyncWorker)
|
||||
|
||||
// RunOnce 由上层定时任务触发:一次性抢占并处理一批任务
|
||||
// - batchSize: 本次抢占数量
|
||||
// - goroutines: 本次并发数(协程池大小)
|
||||
func (w *asyncWorker) RunOnce(ctx context.Context, batchSize, goroutines int) (claimed int, err error) {
|
||||
if batchSize <= 0 {
|
||||
batchSize = 10
|
||||
}
|
||||
if goroutines <= 0 {
|
||||
goroutines = 1
|
||||
}
|
||||
tasks, err := dao.Task.ClaimPendingGlobal(ctx, batchSize)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(tasks) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
pool := grpool.New(goroutines)
|
||||
defer pool.Close()
|
||||
|
||||
claimed = len(tasks)
|
||||
done := make(chan struct{}, claimed)
|
||||
for _, t := range tasks {
|
||||
task := t
|
||||
_ = pool.AddWithRecover(ctx, func(ctx context.Context) {
|
||||
w.handleOne(ctx, task, 0)
|
||||
done <- struct{}{}
|
||||
}, func(ctx context.Context, e error) {
|
||||
if e != nil {
|
||||
_ = dao.Task.UpdateFailedGlobal(ctx, task.Id, fmt.Sprintf("worker panic: %v", e))
|
||||
ReleaseQueueSlot(ctx, task.ModelName, task.TaskID)
|
||||
}
|
||||
done <- struct{}{}
|
||||
})
|
||||
}
|
||||
for i := 0; i < claimed; i++ {
|
||||
<-done
|
||||
}
|
||||
return claimed, nil
|
||||
}
|
||||
|
||||
// RunByTaskID 创建任务后立即异步尝试执行当前任务:
|
||||
// - 只定向抢占当前 taskId 对应的 pending 任务
|
||||
// - 若任务已被其它 worker 抢走/已不在 pending,则直接返回
|
||||
func (w *asyncWorker) RunByTaskID(ctx context.Context, taskID string, epicycleId int64) error {
|
||||
task, err := dao.Task.ClaimPendingByTaskIDGlobal(ctx, taskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if task == nil {
|
||||
return nil
|
||||
}
|
||||
w.handleOne(ctx, task, epicycleId)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicycleId int64) {
|
||||
// 从任务入库的 request_payload 里恢复 payload + headers
|
||||
payload, headers := parseStoredPayload(t.RequestPayload)
|
||||
if len(headers) > 0 {
|
||||
ctx = setTaskHeadersToCtx(ctx, headers)
|
||||
}
|
||||
|
||||
// 1) 拉取模型配置
|
||||
m, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
|
||||
if err != nil {
|
||||
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
|
||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
// ============ 失败回调 ============
|
||||
t.State = 3
|
||||
t.ErrorMsg = err.Error()
|
||||
go triggerCallback(context.WithoutCancel(ctx), t)
|
||||
// ================================
|
||||
return
|
||||
}
|
||||
if m == nil || (m.Enabled != nil && *m.Enabled != 1) {
|
||||
errMsg := "模型不存在或未启用"
|
||||
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, errMsg)
|
||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
// ============ 失败回调 ============
|
||||
t.State = 3
|
||||
t.ErrorMsg = errMsg
|
||||
go triggerCallback(context.WithoutCancel(ctx), t)
|
||||
// ================================
|
||||
return
|
||||
}
|
||||
|
||||
// 2) 分布式并发限制
|
||||
semKey := fmt.Sprintf("asynch:sem:%s", t.ModelName)
|
||||
leaseSeconds := int64(3600)
|
||||
maxC := GetRuntimeMaxConcurrency(ctx, t.ModelName, m.MaxConcurrency)
|
||||
acquired, err := acquireSemaphore(ctx, semKey, maxC, leaseSeconds)
|
||||
if err != nil {
|
||||
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
|
||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
// ============ 失败回调 ============
|
||||
t.State = 3
|
||||
t.ErrorMsg = err.Error()
|
||||
go triggerCallback(context.WithoutCancel(ctx), t)
|
||||
// ================================
|
||||
return
|
||||
}
|
||||
if !acquired {
|
||||
// 并发满了:放回排队,不回调(不是失败)
|
||||
_ = w.rollbackToPending(ctx, t.Id)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = releaseSemaphore(ctx, semKey)
|
||||
}()
|
||||
|
||||
// 3) 调用模型服务
|
||||
if payload == nil {
|
||||
payload = map[string]any{
|
||||
"taskId": t.TaskID,
|
||||
"inputRef": t.InputRef,
|
||||
}
|
||||
}
|
||||
var (
|
||||
data []byte
|
||||
contentType string
|
||||
ext string
|
||||
textResult string
|
||||
)
|
||||
|
||||
// phase=1 表示模型已成功但 OSS 上传失败:优先从临时文件加载
|
||||
if t.Phase == 1 && strings.TrimSpace(t.TmpFile) != "" {
|
||||
data, err = loadTmpResult(t.TmpFile)
|
||||
if err == nil && len(data) > 0 {
|
||||
contentType, ext = DetectFileType(data)
|
||||
} else {
|
||||
data = nil
|
||||
}
|
||||
}
|
||||
if data == nil {
|
||||
// 统计
|
||||
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(t.TenantId), t.Creator, t.ModelName)
|
||||
// 核心调用
|
||||
data, err = InvokeModel(ctx, m, payload, t.ModelKey)
|
||||
if err != nil {
|
||||
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
|
||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
// ============ 失败回调 ============
|
||||
t.State = 3
|
||||
t.ErrorMsg = err.Error()
|
||||
go triggerCallback(context.WithoutCancel(ctx), t)
|
||||
// ================================
|
||||
return
|
||||
}
|
||||
contentType, ext = DetectFileType(data)
|
||||
if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") {
|
||||
textResult = string(data)
|
||||
}
|
||||
tmpPath, err := saveTmpResult(t.TaskID, data, ext)
|
||||
if err == nil && tmpPath != "" {
|
||||
t.TmpFile = tmpPath
|
||||
t.Phase = 1
|
||||
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, t.Id, tmpPath)
|
||||
}
|
||||
}
|
||||
|
||||
// 4) 存储 OSS
|
||||
ossURL, err := Storage.UploadByTask(ctx, t, data, ext, contentType)
|
||||
if err != nil {
|
||||
// OSS 阶段失败:保留临时文件,下一轮仅重试 OSS
|
||||
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, t.Id, err.Error())
|
||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
// ============ OSS失败不回调(还会重试) ============
|
||||
// 注意:OSS失败保留临时文件,下次重试,所以这里不触发最终回调
|
||||
// 如果已经重试多次还没成功,需要在任务超时或超过最大重试次数时才回调失败
|
||||
return
|
||||
}
|
||||
|
||||
// 5) 更新任务状态成功
|
||||
fileType := strings.TrimPrefix(ext, ".")
|
||||
if fileType == "" {
|
||||
fileType = contentType
|
||||
}
|
||||
if err := dao.Task.UpdateSuccessGlobal(
|
||||
ctx,
|
||||
t.Id,
|
||||
ossURL,
|
||||
fileType,
|
||||
textResult,
|
||||
int64(len(data)),
|
||||
nil,
|
||||
GetExpendTokens(m.TokenMapping, textResult),
|
||||
); err != nil {
|
||||
g.Log().Errorf(ctx, "[worker] update success failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 成功/失败均不再占用 queue_limit
|
||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
|
||||
// 6) 成功回调
|
||||
t.State = 2
|
||||
t.OssFile = ossURL
|
||||
t.FileType = fileType
|
||||
t.TextResult = textResult
|
||||
g.Log().Infof(ctx, "[CALLBACK][DISPATCH] taskId=%s bizName=%s callbackUrl=%s", t.TaskID, t.BizName, t.CallbackURL)
|
||||
go triggerCallback(context.WithoutCancel(ctx), t)
|
||||
// ============ 如果有 epicycleId,也触发业务回调 ============
|
||||
if epicycleId != 0 {
|
||||
go triggerPromptsCallback(context.WithoutCancel(ctx), t, epicycleId)
|
||||
}
|
||||
|
||||
// 成功后清理临时文件
|
||||
deleteTmpResult(t.TmpFile)
|
||||
}
|
||||
|
||||
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
|
||||
return dao.Task.RollbackToPendingGlobal(ctx, id)
|
||||
}
|
||||
|
||||
// GetExpendTokens 根据映射路径从 textResult 中提取消耗 token 值
|
||||
func GetExpendTokens(tokenMapping string, textResult string) int {
|
||||
value := gjson.Get(textResult, tokenMapping)
|
||||
if value.Exists() {
|
||||
return int(value.Int())
|
||||
} else {
|
||||
return len(textResult)
|
||||
}
|
||||
}
|
||||
Binary file not shown.
Binary file not shown.
@@ -1 +0,0 @@
|
||||
Asia/Shanghai
|
||||
123
update.sql
123
update.sql
@@ -8,41 +8,57 @@
|
||||
-- 1) asynch_models
|
||||
-- =========================
|
||||
-- Model configuration table (post-change column list of this diff hunk).
-- The last column must not be followed by a comma: PostgreSQL rejects a
-- trailing comma before the closing parenthesis as a syntax error.
CREATE TABLE IF NOT EXISTS asynch_models (
    -- ========== Base fields ==========
    id BIGINT PRIMARY KEY,
    tenant_id BIGINT NOT NULL DEFAULT 0,
    creator VARCHAR(64) NOT NULL,
    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    updater VARCHAR(64) NOT NULL,
    updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    deleted_at TIMESTAMP(6),

    -- ========== Model identity ==========
    model_name VARCHAR(128) NOT NULL,
    model_type SMALLINT NOT NULL DEFAULT 0,
    operator_name VARCHAR(64) NOT NULL DEFAULT '',

    -- ========== Request configuration ==========
    base_url VARCHAR(256) NOT NULL,
    http_method VARCHAR(8) NOT NULL DEFAULT 'POST',
    head_msg JSONB NOT NULL DEFAULT '{}'::jsonb,
    api_key VARCHAR(256) NOT NULL DEFAULT '',

    -- ========== Status switches ==========
    is_private SMALLINT NOT NULL DEFAULT 0,
    enabled SMALLINT NOT NULL DEFAULT 1,
    is_chat_model SMALLINT NOT NULL DEFAULT 0,
    is_async SMALLINT NOT NULL DEFAULT 0,
    is_stream SMALLINT NOT NULL DEFAULT 0,
    is_owner SMALLINT NOT NULL DEFAULT 99,

    -- ========== Configuration ==========
    form_json JSONB NOT NULL DEFAULT '{}'::jsonb,
    request_mapping JSONB NOT NULL DEFAULT '{}'::jsonb,
    response_mapping JSONB NOT NULL DEFAULT '{}'::jsonb,
    response_body JSONB NOT NULL DEFAULT '{}'::jsonb,
    token_config JSONB NOT NULL DEFAULT '{}'::jsonb,
    extend_mapping JSONB NOT NULL DEFAULT '{}'::jsonb,
    query_config JSONB NOT NULL DEFAULT '{}'::jsonb,
    stream_config JSONB NOT NULL DEFAULT '{}'::jsonb,
    first_frame VARCHAR(128) NOT NULL DEFAULT '',
    last_frame VARCHAR(128) NOT NULL DEFAULT '',

    -- ========== Limits and retries ==========
    max_concurrency INT NOT NULL DEFAULT 10,
    timeout_seconds INT NOT NULL DEFAULT 600,
    retry_times SMALLINT NOT NULL DEFAULT 3,
    auto_clean_seconds INT NOT NULL DEFAULT 86400,

    -- ========== Miscellaneous ==========
    response_token_field VARCHAR(128) NOT NULL DEFAULT ''
);
|
||||
|
||||
-- ========== 索引 ==========
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS uk_asynch_models_tenant_creator_chat ON asynch_models(tenant_id, creator) WHERE is_chat_model = 1 AND deleted_at IS NULL;
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS uk_asynch_models_tenant_model_name ON asynch_models(tenant_id, creator, model_name);
|
||||
CREATE INDEX IF NOT EXISTS idx_asynch_models_tenant_id ON asynch_models(tenant_id);
|
||||
@@ -51,7 +67,9 @@ CREATE INDEX IF NOT EXISTS idx_asynch_models_model_type ON asynch_models(model_t
|
||||
CREATE INDEX IF NOT EXISTS idx_asynch_models_enabled ON asynch_models(enabled);
|
||||
CREATE INDEX IF NOT EXISTS idx_asynch_models_deleted_at ON asynch_models(deleted_at);
|
||||
|
||||
-- ========== 注释 ==========
|
||||
COMMENT ON TABLE asynch_models IS '模型配置表';
|
||||
|
||||
COMMENT ON COLUMN asynch_models.id IS '主键ID(非自增)';
|
||||
COMMENT ON COLUMN asynch_models.tenant_id IS '租户ID';
|
||||
COMMENT ON COLUMN asynch_models.creator IS '创建人';
|
||||
@@ -62,29 +80,32 @@ COMMENT ON COLUMN asynch_models.deleted_at IS '删除时间(软删)';
|
||||
|
||||
COMMENT ON COLUMN asynch_models.model_name IS '模型名称';
|
||||
COMMENT ON COLUMN asynch_models.model_type IS '模型类型';
|
||||
COMMENT ON COLUMN asynch_models.operator_name IS '运营商名称';
|
||||
COMMENT ON COLUMN asynch_models.base_url IS '模型地址';
|
||||
COMMENT ON COLUMN asynch_models.http_method IS '请求方式 GET/POST';
|
||||
COMMENT ON COLUMN asynch_models.head_msg IS '请求头绑定(支持多个,逗号分隔)示例 X-API:xxx,operation:true';
|
||||
COMMENT ON COLUMN asynch_models.is_private IS '是否私有化 0-私有 1-公共';
|
||||
COMMENT ON COLUMN asynch_models.enabled IS '是否启用 0停用 1-启用';
|
||||
COMMENT ON COLUMN asynch_models.is_chat_model IS '是否为对话模型 0-否 1-是';
|
||||
COMMENT ON COLUMN asynch_models.is_owner IS '1=当前用户创建的,0=超级管理员的';
|
||||
COMMENT ON COLUMN asynch_models.api_key IS '调用凭证,密钥';
|
||||
COMMENT ON COLUMN asynch_models.prompt IS '提示词内容(文本)';
|
||||
COMMENT ON COLUMN asynch_models.form_json IS '表单结构(用于前端渲染,也用于后端校验)';
|
||||
COMMENT ON COLUMN asynch_models.head_msg IS '请求头信息';
|
||||
COMMENT ON COLUMN asynch_models.api_key IS '调用凭证/密钥';
|
||||
COMMENT ON COLUMN asynch_models.is_private IS '是否私有化:0-私有 1-公共';
|
||||
COMMENT ON COLUMN asynch_models.enabled IS '是否启用:0-停用 1-启用';
|
||||
COMMENT ON COLUMN asynch_models.is_chat_model IS '是否为对话模型:0-否 1-是';
|
||||
COMMENT ON COLUMN asynch_models.is_async IS '是否异步:0-同步 1-异步';
|
||||
COMMENT ON COLUMN asynch_models.is_stream IS '是否流式:0-非流式 1-流式';
|
||||
COMMENT ON COLUMN asynch_models.is_owner IS '1=当前用户创建 0=超级管理员';
|
||||
COMMENT ON COLUMN asynch_models.form_json IS '动态表单结构';
|
||||
COMMENT ON COLUMN asynch_models.request_mapping IS '请求映射';
|
||||
COMMENT ON COLUMN asynch_models.response_mapping IS '返回映射';
|
||||
COMMENT ON COLUMN asynch_models.response_body IS '返回主体';
|
||||
COMMENT ON COLUMN asynch_models.max_concurrency IS '单模型最大并发';
|
||||
COMMENT ON COLUMN asynch_models.queue_limit IS '排队上限(近似控制)';
|
||||
COMMENT ON COLUMN asynch_models.timeout_seconds IS '调用模型服务超时(秒)';
|
||||
COMMENT ON COLUMN asynch_models.expected_seconds IS '模型预计执行时间(秒)';
|
||||
COMMENT ON COLUMN asynch_models.token_config IS 'Token计算配置';
|
||||
COMMENT ON COLUMN asynch_models.extend_mapping IS '附加映射';
|
||||
COMMENT ON COLUMN asynch_models.query_config IS '查询/回调配置';
|
||||
COMMENT ON COLUMN asynch_models.stream_config IS '流式输出配置';
|
||||
COMMENT ON COLUMN asynch_models.first_frame IS '首帧图片参数';
|
||||
COMMENT ON COLUMN asynch_models.last_frame IS '尾帧图片参数';
|
||||
COMMENT ON COLUMN asynch_models.max_concurrency IS '最大并发数';
|
||||
COMMENT ON COLUMN asynch_models.timeout_seconds IS '调用模型超时(秒)';
|
||||
COMMENT ON COLUMN asynch_models.retry_times IS '失败重试次数';
|
||||
COMMENT ON COLUMN asynch_models.retry_queue_max_seconds IS '失败重试最大排队时间(秒 0=插队到队首;>0=排队超过该时间后插队,否则仍到队尾)';
|
||||
COMMENT ON COLUMN asynch_models.auto_clean_seconds IS '已下载(state=4 后的保留时间(秒),到期清理)';
|
||||
COMMENT ON COLUMN asynch_models.remark IS '备注';
|
||||
COMMENT ON COLUMN asynch_models.token_mapping IS 'token映射';
|
||||
|
||||
COMMENT ON COLUMN asynch_models.auto_clean_seconds IS '任务完成后自动清理时间(秒)';
|
||||
COMMENT ON COLUMN asynch_models.response_token_field IS '响应中消耗token的字段映射';
|
||||
|
||||
|
||||
-- =========================
|
||||
|
||||
Reference in New Issue
Block a user