refactor(asynch): 重构异步模型配置和队列管理

This commit is contained in:
2026-06-02 20:26:45 +08:00
parent c7e9eb889b
commit 52124385a1
18 changed files with 726 additions and 1006 deletions

View File

@@ -1,11 +1,17 @@
package util
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"model-gateway/model/entity"
"net/http"
"net/url"
"regexp"
"strings"
"time"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/frame/g"
@@ -18,14 +24,9 @@ import (
func ValidatePromptResult(raw map[string]any, model *entity.AsynchModel) error {
// 1) 获取校验配置,并取值
requestMapping := model.RequestMapping
contentKey := ""
for k := range model.ResponseBody {
contentKey = k
break
}
contentStr, ok := raw[contentKey].(string)
contentStr, ok := raw[model.ResponseBody].(string)
if !ok || contentStr == "" {
return fmt.Errorf("%s 字段为空或不是字符串", contentKey)
return fmt.Errorf("%s 字段为空或不是字符串", model.ResponseBody)
}
// 2) 解析 content 为 JSON 数组
@@ -105,56 +106,39 @@ func MapResponsePayload(mapping map[string]any, result map[string]any) (map[stri
return mapped, nil
}
// ParseHeadMsgHeaders 支持多个 header 绑定,逗号分隔:
// 示例:
// - X-API-Key:qwen3-tts-key,operation:true,count:123
// - X-API-Key:"qwen3-tts-key",operation:"true"
// ParseHeadMsgHeaders head_msg JSON 中提取请求头
// head_msg 格式示例:
//
// 说明:
// - HTTP Header 最终都是字符串,这里做的是“值的字符串化表达”。
// - 若 value 用双引号包裹,会去掉外层引号再注入,便于在配置中区分字符串/布尔/数字等表达(以及避免值中包含特殊字符时歧义)。
func ParseHeadMsgHeaders(headMsg string) map[string]string {
headMsg = strings.TrimSpace(headMsg)
if headMsg == "" {
// {
// "Authorization": "Bearer xxx",
// "Content-Type": "application/json",
// "X-Api-App-Id": "5147401364",
// "X-Api-Access-Key": "VCqRX7..."
// }
func ParseHeadMsgHeaders(headMsg map[string]any) map[string]string {
if len(headMsg) == 0 {
return nil
}
out := map[string]string{}
parts := strings.Split(headMsg, ",")
for _, p := range parts {
p = strings.TrimSpace(p)
if p == "" {
continue
}
// HeaderName:HeaderValue推荐 / HeaderName=HeaderValue兼容
if strings.Contains(p, ":") {
kv := strings.SplitN(p, ":", 2)
k := strings.TrimSpace(kv[0])
v := strings.TrimSpace(kv[1])
v = strings.Trim(v, "\"")
if k != "" && v != "" {
out[k] = v
}
continue
}
if strings.Contains(p, "=") {
kv := strings.SplitN(p, "=", 2)
k := strings.TrimSpace(kv[0])
v := strings.TrimSpace(kv[1])
v = strings.Trim(v, "\"")
if k != "" && v != "" {
out[k] = v
}
continue
}
}
if len(out) == 0 {
return nil
out := make(map[string]string, len(headMsg))
for k, v := range headMsg {
out[k] = gconv.String(v)
}
return out
}
// PayloadToQuery 将 payload 转为 url.Values
func PayloadToQuery(payload map[string]any) (url.Values, error) {
// GetModelBody 获取数据库中保存的模型信息
func GetModelBody(v map[string]any) map[string]any {
if v == nil {
return nil
}
if p, ok := v["body"]; ok {
return gconv.Map(p)
}
return v
}
// BodyToQuery 将 body 转为 url.Values
func BodyToQuery(payload map[string]any) (url.Values, error) {
q := url.Values{}
for k, v := range payload {
if v == nil {
@@ -164,3 +148,142 @@ func PayloadToQuery(payload map[string]any) (url.Values, error) {
}
return q, nil
}
// PullTaskResult 轮询查询任务结果直到完成
func PullTaskResult(ctx context.Context, taskID string, queryConfig map[string]any) (map[string]any, error) {
// 1. 解析配置
url := gconv.String(queryConfig["url"])
method := gconv.String(queryConfig["method"])
headers, _ := queryConfig["headers"].(map[string]any)
interval := gconv.Int(queryConfig["interval_seconds"])
if interval <= 0 {
interval = 2
}
if method == "" {
method = "GET"
}
// 2. 构建参数
params := map[string]any{"id": taskID}
// 3. 替换 URL 中的 {id}
finalURL := replaceURLParams(url, params)
// 4. 构建请求体
bodyCfg, _ := queryConfig["body"].(map[string]any)
body := buildParams(bodyCfg, params)
// 5. 轮询
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
var reqBody io.Reader
if method == "POST" && body != nil {
bs, _ := json.Marshal(body)
reqBody = bytes.NewReader(bs)
}
req, err := http.NewRequestWithContext(ctx, method, finalURL, reqBody)
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
for k, v := range headers {
req.Header.Set(k, gconv.String(v))
}
if req.Header.Get("Content-Type") == "" && reqBody != nil {
req.Header.Set("Content-Type", "application/json")
}
client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
g.Log().Warningf(ctx, "[PullTaskResult] 请求失败 taskID=%s err=%v", taskID, err)
time.Sleep(time.Duration(interval) * time.Second)
continue
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
all, _ := io.ReadAll(resp.Body)
resp.Body.Close()
g.Log().Warningf(ctx, "[PullTaskResult] 请求异常 taskID=%s status=%d body=%s", taskID, resp.StatusCode, string(all))
time.Sleep(time.Duration(interval) * time.Second)
continue
}
var result map[string]any
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
resp.Body.Close()
g.Log().Warningf(ctx, "[PullTaskResult] 解析失败 taskID=%s err=%v", taskID, err)
time.Sleep(time.Duration(interval) * time.Second)
continue
}
resp.Body.Close()
status := gconv.String(result["status"])
g.Log().Infof(ctx, "[PullTaskResult] 轮询 taskID=%s status=%s", taskID, status)
switch status {
case "succeeded":
return result, nil
case "failed", "expired":
return result, fmt.Errorf("任务失败: status=%s", status)
case "queued", "running":
time.Sleep(time.Duration(interval) * time.Second)
continue
default:
// 兼容没有 status 字段的情况,直接返回
return result, nil
}
}
}
// buildParams 构建请求参数,用 params 覆盖 bodyCfg 中对应 key
func buildParams(bodyCfg map[string]any, params map[string]any) map[string]any {
result := make(map[string]any, len(bodyCfg)+len(params))
for k, v := range bodyCfg {
result[k] = v
}
for k, v := range params {
result[k] = v
}
return result
}
// replaceURLParams 替换 URL 中的 {key}
func replaceURLParams(url string, params map[string]any) string {
re := regexp.MustCompile(`\{([^}]+)\}`)
return re.ReplaceAllStringFunc(url, func(s string) string {
key := strings.Trim(s, "{}")
if val, ok := params[key]; ok {
return gconv.String(val)
}
return s
})
}
// replaceBodyParams 用 params 覆盖 body 中对应 key
func replaceBodyParams(bodyCfg map[string]any, params map[string]any) map[string]any {
result := make(map[string]any)
for k, v := range bodyCfg {
result[k] = v
}
for k, v := range params {
result[k] = v
}
return result
}
// InjectCallbackURL 将回调地址注入到请求体中
func InjectCallbackURL(ctx context.Context, payload map[string]any, callbackURL string) map[string]any {
if callbackURL == "" {
return payload
}
payload[callbackURL] = GetCallbackURL(ctx, "/task/modelCallback")
return payload
}