refactor(util): 重构映射工具函数并优化异步任务轮询逻辑

This commit is contained in:
2026-06-03 13:30:39 +08:00
parent 2c7838807b
commit bcfcc7ed47
6 changed files with 99 additions and 131 deletions

View File

@@ -52,21 +52,24 @@ func ValidatePromptResult(raw map[string]any, model *entity.AsynchModel) error {
return nil
}
// ReverseMap 映射 payload 到 mapping
func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
jsonObj := gjson.New("{}")
for path, defaultValue := range mapping {
// 从 payload 取对应路径的值
val := gjson.New(payload).Get(path)
if !val.IsNil() {
// payload 有值,用它
_ = jsonObj.Set(path, val.Val())
} else if !g.IsEmpty(defaultValue) {
// payload 没值,用默认值
_ = jsonObj.Set(path, defaultValue)
}
// ParseHeadMsgHeaders 从 head_msg JSON 中提取请求头
// head_msg 格式示例:
//
// {
// "Authorization": "Bearer xxx",
// "Content-Type": "application/json",
// "X-Api-App-Id": "5147401364",
// "X-Api-Access-Key": "VCqRX7..."
// }
func ParseHeadMsgHeaders(headMsg map[string]any) map[string]string {
if len(headMsg) == 0 {
return nil
}
return jsonObj.Map()
out := make(map[string]string, len(headMsg))
for k, v := range headMsg {
out[k] = gconv.String(v)
}
return out
}
// MapResponsePayload 映射模型响应为标准格式
@@ -106,26 +109,6 @@ func MapResponsePayload(mapping map[string]any, result map[string]any) (map[stri
return mapped, nil
}
// ParseHeadMsgHeaders 从 head_msg JSON 中提取请求头
// head_msg 格式示例:
//
// {
// "Authorization": "Bearer xxx",
// "Content-Type": "application/json",
// "X-Api-App-Id": "5147401364",
// "X-Api-Access-Key": "VCqRX7..."
// }
func ParseHeadMsgHeaders(headMsg map[string]any) map[string]string {
if len(headMsg) == 0 {
return nil
}
out := make(map[string]string, len(headMsg))
for k, v := range headMsg {
out[k] = gconv.String(v)
}
return out
}
// GetModelBody 获取数据库中保存的模型信息
func GetModelBody(v map[string]any) map[string]any {
if v == nil {
@@ -149,32 +132,44 @@ func BodyToQuery(payload map[string]any) (url.Values, error) {
return q, nil
}
// PullTaskResult 轮询查询任务结果直到完成
func PullTaskResult(ctx context.Context, taskID string, queryConfig map[string]any) (map[string]any, error) {
// 1. 解析配置
url := gconv.String(queryConfig["url"])
// PullTaskResult 轮询查询异步任务结果直到完成
func PullTaskResult(ctx context.Context, body map[string]any, queryConfig map[string]any, headMsg map[string]any) (map[string]any, error) {
// 1) 解析配置
// 1.1 提取 taskID
taskIDPath := gconv.String(queryConfig["task_id"])
taskID := gconv.String(gjson.New(body).Get(taskIDPath).Val())
if taskID == "" {
return nil, fmt.Errorf("无法从路径 %s 提取 taskID", taskIDPath)
}
g.Log().Infof(ctx, "[PullTaskResult] taskID=%s", taskID)
// 1.2 请求地址,替换 {id}
queryUrl := gconv.String(queryConfig["url"])
queryUrl = replaceURLParams(queryUrl, map[string]any{"id": taskID})
// 1.3 请求方式
method := gconv.String(queryConfig["method"])
headers, _ := queryConfig["headers"].(map[string]any)
if method == "" {
method = "GET"
}
// 1.4 状态判断配置
statusPath := gconv.String(queryConfig["status_path"])
statusValues, _ := queryConfig["status_values"].(map[string]any)
if statusPath == "" {
statusPath = "status"
}
// 1.5 轮询间隔
interval := gconv.Int(queryConfig["interval_seconds"])
if interval <= 0 {
interval = 2
}
if method == "" {
method = "GET"
}
// 1.6 请求体
reqBodyMap := map[string]any{"task_id": taskID}
// 2. 构建参数
params := map[string]any{"id": taskID}
// 3. 替换 URL 中的 {id}
finalURL := replaceURLParams(url, params)
// 4. 构建请求体
bodyCfg, _ := queryConfig["body"].(map[string]any)
body := buildParams(bodyCfg, params)
// 5. 轮询
// 2) 轮询请求
for {
select {
case <-ctx.Done():
@@ -183,21 +178,19 @@ func PullTaskResult(ctx context.Context, taskID string, queryConfig map[string]a
}
var reqBody io.Reader
if method == "POST" && body != nil {
bs, _ := json.Marshal(body)
if method == "POST" {
bs, _ := json.Marshal(reqBodyMap)
reqBody = bytes.NewReader(bs)
}
req, err := http.NewRequestWithContext(ctx, method, finalURL, reqBody)
req, err := http.NewRequestWithContext(ctx, method, queryUrl, reqBody)
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
for k, v := range headers {
req.Header.Set(k, gconv.String(v))
}
if req.Header.Get("Content-Type") == "" && reqBody != nil {
req.Header.Set("Content-Type", "application/json")
// 统一用 headMsg 注入请求头
for hk, hv := range ParseHeadMsgHeaders(headMsg) {
req.Header.Set(hk, hv)
}
client := &http.Client{Timeout: 30 * time.Second}
@@ -208,56 +201,54 @@ func PullTaskResult(ctx context.Context, taskID string, queryConfig map[string]a
continue
}
raw, _ := io.ReadAll(resp.Body)
_ = resp.Body.Close()
g.Log().Infof(ctx, "[PullTaskResult] taskID=%s statusCode=%d body=%s", taskID, resp.StatusCode, string(raw))
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
all, _ := io.ReadAll(resp.Body)
resp.Body.Close()
g.Log().Warningf(ctx, "[PullTaskResult] 请求异常 taskID=%s status=%d body=%s", taskID, resp.StatusCode, string(all))
time.Sleep(time.Duration(interval) * time.Second)
continue
}
var result map[string]any
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
resp.Body.Close()
g.Log().Warningf(ctx, "[PullTaskResult] 解析失败 taskID=%s err=%v", taskID, err)
time.Sleep(time.Duration(interval) * time.Second)
continue
}
resp.Body.Close()
_ = json.Unmarshal(raw, &result)
status := gconv.String(result["status"])
g.Log().Infof(ctx, "[PullTaskResult] 轮询 taskID=%s status=%s", taskID, status)
statusVal := gjson.New(result).Get(statusPath).Val()
statusStr := gconv.String(statusVal)
g.Log().Infof(ctx, "[PullTaskResult] 状态 taskID=%s status=%v", taskID, statusVal)
switch status {
case "succeeded":
return result, nil
case "failed", "expired":
return result, fmt.Errorf("任务失败: status=%s", status)
case "queued", "running":
time.Sleep(time.Duration(interval) * time.Second)
continue
default:
// 兼容没有 status 字段的情况,直接返回
if matchStatus(statusStr, statusValues["succeeded"]) {
g.Log().Infof(ctx, "[PullTaskResult] 任务成功 taskID=%s", taskID)
return result, nil
}
if matchStatus(statusStr, statusValues["failed"]) {
g.Log().Errorf(ctx, "[PullTaskResult] 任务失败 taskID=%s", taskID)
return result, fmt.Errorf("任务失败")
}
time.Sleep(time.Duration(interval) * time.Second)
}
}
// buildParams 构建请求参数,用 params 覆盖 bodyCfg 中对应 key
func buildParams(bodyCfg map[string]any, params map[string]any) map[string]any {
result := make(map[string]any, len(bodyCfg)+len(params))
for k, v := range bodyCfg {
result[k] = v
func matchStatus(actual string, expected any) bool {
switch v := expected.(type) {
case string:
return actual == v
case []any:
for _, item := range v {
if actual == gconv.String(item) {
return true
}
}
}
for k, v := range params {
result[k] = v
}
return result
return false
}
// replaceURLParams 替换 URL 中的 {key}
func replaceURLParams(url string, params map[string]any) string {
re := regexp.MustCompile(`\{([^}]+)\}`)
re := regexp.MustCompile(`\{([^}]+)}`)
return re.ReplaceAllStringFunc(url, func(s string) string {
key := strings.Trim(s, "{}")
if val, ok := params[key]; ok {
@@ -267,18 +258,6 @@ func replaceURLParams(url string, params map[string]any) string {
})
}
// replaceBodyParams 用 params 覆盖 body 中对应 key
func replaceBodyParams(bodyCfg map[string]any, params map[string]any) map[string]any {
result := make(map[string]any)
for k, v := range bodyCfg {
result[k] = v
}
for k, v := range params {
result[k] = v
}
return result
}
// InjectCallbackURL 将回调地址注入到请求体中
func InjectCallbackURL(ctx context.Context, payload map[string]any, callbackURL string) map[string]any {
if callbackURL == "" {