refactor(asynch): 重构异步模型配置和队列管理
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -67,3 +68,48 @@ func SaveTmpResult(taskID string, data []byte, ext string) (string, error) {
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// SaveTempFileByType
|
||||
// 根据传入的数据自动判断:
|
||||
// 若是 []byte 且后缀为 .mp3 → 保存二进制音频
|
||||
// 若是任意结构体/map → 自动转 JSON 保存
|
||||
// 返回:新临时文件路径、错误
|
||||
func SaveTempFileByType(taskID string, data any, oldTmpFile string) (string, error) {
|
||||
// 1. 先清理旧临时文件(统一逻辑)
|
||||
if oldTmpFile != "" {
|
||||
_ = os.Remove(oldTmpFile)
|
||||
}
|
||||
|
||||
var tmpPath string
|
||||
var tmpErr error
|
||||
|
||||
// 2. 判断是否是二进制音频([]byte + .mp3)
|
||||
if audioData, ok := data.([]byte); ok {
|
||||
tmpPath, tmpErr = saveTmpResult(taskID, audioData, ".mp3")
|
||||
} else {
|
||||
// 3. 其他类型 → 序列化为 JSON 保存
|
||||
mappedBytes, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(mappedBytes) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
tmpPath, tmpErr = saveTmpResult(taskID, mappedBytes, ".json")
|
||||
}
|
||||
|
||||
if tmpErr != nil || tmpPath == "" {
|
||||
return "", tmpErr
|
||||
}
|
||||
|
||||
return tmpPath, nil
|
||||
}
|
||||
|
||||
// saveTmpResult 你原有的底层保存文件方法(保留不动)
|
||||
func saveTmpResult(taskID string, data []byte, ext string) (string, error) {
|
||||
// 你原来实现,比如:
|
||||
filename := taskID + ext
|
||||
tmpPath := filepath.Join(os.TempDir(), filename)
|
||||
err := os.WriteFile(tmpPath, data, 0644)
|
||||
return tmpPath, err
|
||||
}
|
||||
|
||||
@@ -77,14 +77,3 @@ func SetTaskHeadersToCtx(ctx context.Context, headers map[string]string) context
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
// ParseStoredPayload 解析入库的 request_payload,拆出模型调用核心数据
|
||||
func ParseStoredPayload(v map[string]any) map[string]any {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
if p, ok := v["payload"]; ok {
|
||||
return gconv.Map(p)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
@@ -1,11 +1,17 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"model-gateway/model/entity"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
@@ -18,14 +24,9 @@ import (
|
||||
func ValidatePromptResult(raw map[string]any, model *entity.AsynchModel) error {
|
||||
// 1) 获取校验配置,并取值
|
||||
requestMapping := model.RequestMapping
|
||||
contentKey := ""
|
||||
for k := range model.ResponseBody {
|
||||
contentKey = k
|
||||
break
|
||||
}
|
||||
contentStr, ok := raw[contentKey].(string)
|
||||
contentStr, ok := raw[model.ResponseBody].(string)
|
||||
if !ok || contentStr == "" {
|
||||
return fmt.Errorf("%s 字段为空或不是字符串", contentKey)
|
||||
return fmt.Errorf("%s 字段为空或不是字符串", model.ResponseBody)
|
||||
}
|
||||
|
||||
// 2) 解析 content 为 JSON 数组
|
||||
@@ -105,56 +106,39 @@ func MapResponsePayload(mapping map[string]any, result map[string]any) (map[stri
|
||||
return mapped, nil
|
||||
}
|
||||
|
||||
// ParseHeadMsgHeaders 支持多个 header 绑定,逗号分隔:
|
||||
// 示例:
|
||||
// - X-API-Key:qwen3-tts-key,operation:true,count:123
|
||||
// - X-API-Key:"qwen3-tts-key",operation:"true"
|
||||
// ParseHeadMsgHeaders 从 head_msg JSON 中提取请求头
|
||||
// head_msg 格式示例:
|
||||
//
|
||||
// 说明:
|
||||
// - HTTP Header 最终都是字符串,这里做的是“值的字符串化表达”。
|
||||
// - 若 value 用双引号包裹,会去掉外层引号再注入,便于在配置中区分字符串/布尔/数字等表达(以及避免值中包含特殊字符时歧义)。
|
||||
func ParseHeadMsgHeaders(headMsg string) map[string]string {
|
||||
headMsg = strings.TrimSpace(headMsg)
|
||||
if headMsg == "" {
|
||||
// {
|
||||
// "Authorization": "Bearer xxx",
|
||||
// "Content-Type": "application/json",
|
||||
// "X-Api-App-Id": "5147401364",
|
||||
// "X-Api-Access-Key": "VCqRX7..."
|
||||
// }
|
||||
func ParseHeadMsgHeaders(headMsg map[string]any) map[string]string {
|
||||
if len(headMsg) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := map[string]string{}
|
||||
parts := strings.Split(headMsg, ",")
|
||||
for _, p := range parts {
|
||||
p = strings.TrimSpace(p)
|
||||
if p == "" {
|
||||
continue
|
||||
}
|
||||
// HeaderName:HeaderValue(推荐) / HeaderName=HeaderValue(兼容)
|
||||
if strings.Contains(p, ":") {
|
||||
kv := strings.SplitN(p, ":", 2)
|
||||
k := strings.TrimSpace(kv[0])
|
||||
v := strings.TrimSpace(kv[1])
|
||||
v = strings.Trim(v, "\"")
|
||||
if k != "" && v != "" {
|
||||
out[k] = v
|
||||
}
|
||||
continue
|
||||
}
|
||||
if strings.Contains(p, "=") {
|
||||
kv := strings.SplitN(p, "=", 2)
|
||||
k := strings.TrimSpace(kv[0])
|
||||
v := strings.TrimSpace(kv[1])
|
||||
v = strings.Trim(v, "\"")
|
||||
if k != "" && v != "" {
|
||||
out[k] = v
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
out := make(map[string]string, len(headMsg))
|
||||
for k, v := range headMsg {
|
||||
out[k] = gconv.String(v)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// PayloadToQuery 将 payload 转为 url.Values
|
||||
func PayloadToQuery(payload map[string]any) (url.Values, error) {
|
||||
// GetModelBody 获取数据库中保存的模型信息
|
||||
func GetModelBody(v map[string]any) map[string]any {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
if p, ok := v["body"]; ok {
|
||||
return gconv.Map(p)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// BodyToQuery 将 body 转为 url.Values
|
||||
func BodyToQuery(payload map[string]any) (url.Values, error) {
|
||||
q := url.Values{}
|
||||
for k, v := range payload {
|
||||
if v == nil {
|
||||
@@ -164,3 +148,142 @@ func PayloadToQuery(payload map[string]any) (url.Values, error) {
|
||||
}
|
||||
return q, nil
|
||||
}
|
||||
|
||||
// PullTaskResult 轮询查询任务结果直到完成
|
||||
func PullTaskResult(ctx context.Context, taskID string, queryConfig map[string]any) (map[string]any, error) {
|
||||
// 1. 解析配置
|
||||
url := gconv.String(queryConfig["url"])
|
||||
method := gconv.String(queryConfig["method"])
|
||||
headers, _ := queryConfig["headers"].(map[string]any)
|
||||
interval := gconv.Int(queryConfig["interval_seconds"])
|
||||
if interval <= 0 {
|
||||
interval = 2
|
||||
}
|
||||
|
||||
if method == "" {
|
||||
method = "GET"
|
||||
}
|
||||
|
||||
// 2. 构建参数
|
||||
params := map[string]any{"id": taskID}
|
||||
|
||||
// 3. 替换 URL 中的 {id}
|
||||
finalURL := replaceURLParams(url, params)
|
||||
|
||||
// 4. 构建请求体
|
||||
bodyCfg, _ := queryConfig["body"].(map[string]any)
|
||||
body := buildParams(bodyCfg, params)
|
||||
|
||||
// 5. 轮询
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
var reqBody io.Reader
|
||||
if method == "POST" && body != nil {
|
||||
bs, _ := json.Marshal(body)
|
||||
reqBody = bytes.NewReader(bs)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, finalURL, reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, gconv.String(v))
|
||||
}
|
||||
if req.Header.Get("Content-Type") == "" && reqBody != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[PullTaskResult] 请求失败 taskID=%s err=%v", taskID, err)
|
||||
time.Sleep(time.Duration(interval) * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
all, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
g.Log().Warningf(ctx, "[PullTaskResult] 请求异常 taskID=%s status=%d body=%s", taskID, resp.StatusCode, string(all))
|
||||
time.Sleep(time.Duration(interval) * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
resp.Body.Close()
|
||||
g.Log().Warningf(ctx, "[PullTaskResult] 解析失败 taskID=%s err=%v", taskID, err)
|
||||
time.Sleep(time.Duration(interval) * time.Second)
|
||||
continue
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
status := gconv.String(result["status"])
|
||||
g.Log().Infof(ctx, "[PullTaskResult] 轮询 taskID=%s status=%s", taskID, status)
|
||||
|
||||
switch status {
|
||||
case "succeeded":
|
||||
return result, nil
|
||||
case "failed", "expired":
|
||||
return result, fmt.Errorf("任务失败: status=%s", status)
|
||||
case "queued", "running":
|
||||
time.Sleep(time.Duration(interval) * time.Second)
|
||||
continue
|
||||
default:
|
||||
// 兼容没有 status 字段的情况,直接返回
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// buildParams 构建请求参数,用 params 覆盖 bodyCfg 中对应 key
|
||||
func buildParams(bodyCfg map[string]any, params map[string]any) map[string]any {
|
||||
result := make(map[string]any, len(bodyCfg)+len(params))
|
||||
for k, v := range bodyCfg {
|
||||
result[k] = v
|
||||
}
|
||||
for k, v := range params {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// replaceURLParams 替换 URL 中的 {key}
|
||||
func replaceURLParams(url string, params map[string]any) string {
|
||||
re := regexp.MustCompile(`\{([^}]+)\}`)
|
||||
return re.ReplaceAllStringFunc(url, func(s string) string {
|
||||
key := strings.Trim(s, "{}")
|
||||
if val, ok := params[key]; ok {
|
||||
return gconv.String(val)
|
||||
}
|
||||
return s
|
||||
})
|
||||
}
|
||||
|
||||
// replaceBodyParams 用 params 覆盖 body 中对应 key
|
||||
func replaceBodyParams(bodyCfg map[string]any, params map[string]any) map[string]any {
|
||||
result := make(map[string]any)
|
||||
for k, v := range bodyCfg {
|
||||
result[k] = v
|
||||
}
|
||||
for k, v := range params {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// InjectCallbackURL 将回调地址注入到请求体中
|
||||
func InjectCallbackURL(ctx context.Context, payload map[string]any, callbackURL string) map[string]any {
|
||||
if callbackURL == "" {
|
||||
return payload
|
||||
}
|
||||
payload[callbackURL] = GetCallbackURL(ctx, "/task/modelCallback")
|
||||
return payload
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user