package util

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"regexp"
	"strings"
	"time"

	"model-gateway/model/entity"

	"github.com/gogf/gf/v2/encoding/gjson"
	"github.com/gogf/gf/v2/frame/g"
	"github.com/gogf/gf/v2/util/gconv"
	tgjson "github.com/tidwall/gjson"
)

// ValidatePromptResult checks the structural integrity of a model's JSON result.
//
// raw[model.ResponseBody] must be a non-empty string holding a non-empty JSON
// array of objects. For every element, each path in model.RequestMapping whose
// default value is empty (as judged by g.IsEmpty, which also treats 0 and false
// as empty) is required to be present. Paths with a non-empty default are
// skipped because the default can fill the gap. model must not be nil.
//
// When several required fields are missing, which one is reported depends on
// map iteration order and is therefore not deterministic.
func ValidatePromptResult(raw map[string]any, model *entity.AsynchModel) error {
	// 1) Read the validation config and the raw content string.
	requestMapping := model.RequestMapping
	contentStr, ok := raw[model.ResponseBody].(string)
	if !ok || contentStr == "" {
		return fmt.Errorf("%s 字段为空或不是字符串", model.ResponseBody)
	}

	// 2) Decode the content as a JSON array of objects.
	var rounds []map[string]any
	if err := gjson.DecodeTo(contentStr, &rounds); err != nil {
		return fmt.Errorf("解析 content JSON 数组失败: %w", err)
	}
	if len(rounds) == 0 {
		return fmt.Errorf("content 数组为空")
	}

	// 3) Check every element: only fields whose default is empty are required.
	for i, round := range rounds {
		// Wrap each element once instead of once per mapping entry.
		roundJSON := gjson.New(round)
		for path, defaultValue := range requestMapping {
			if !g.IsEmpty(defaultValue) {
				continue
			}
			if roundJSON.Get(path).IsNil() {
				return fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, path)
			}
		}
	}
	return nil
}

// ReverseMap builds a new object whose keys are the gf gjson paths in mapping.
//
// For each path, the value found at the same path in payload is copied over.
// When payload has nothing there, the mapping's default value is used instead,
// unless that default is empty (g.IsEmpty), in which case the path is omitted.
// Set errors are ignored (best effort).
func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
	src := gjson.New(payload) // loop-invariant: wrap payload once
	dst := gjson.New("{}")
	for path, defaultValue := range mapping {
		if val := src.Get(path); !val.IsNil() {
			// payload supplies a value: prefer it.
			_ = dst.Set(path, val.Val())
		} else if !g.IsEmpty(defaultValue) {
			// payload has no value: fall back to the configured default.
			_ = dst.Set(path, defaultValue)
		}
	}
	return dst.Map()
}

// MapResponsePayload projects a model response onto the standard field names.
//
// mapping maps each standard field name to a tidwall/gjson path into result.
// Empty paths and paths that match nothing are skipped. Paths containing '#'
// (gjson array queries) always yield a []any. An empty match yields an empty
// slice rather than nil, so it encodes as [] instead of null. Other paths
// yield the plain decoded value. With an empty mapping, result is returned
// unchanged. An error is returned only if result cannot be JSON-encoded.
func MapResponsePayload(mapping map[string]any, result map[string]any) (map[string]any, error) {
	if len(mapping) == 0 {
		return result, nil
	}

	// tidwall/gjson queries raw JSON text, so serialize result once up front.
	resultBytes, err := json.Marshal(result)
	if err != nil {
		return nil, fmt.Errorf("序列化模型响应失败: %w", err)
	}
	resultStr := string(resultBytes)

	mapped := make(map[string]any, len(mapping))
	for standardField, modelPath := range mapping {
		path := gconv.String(modelPath)
		if path == "" {
			continue
		}
		value := tgjson.Get(resultStr, path)
		if !value.Exists() {
			continue
		}
		if !strings.Contains(path, "#") {
			mapped[standardField] = value.Value()
			continue
		}
		// Array query: always emit a (possibly empty) slice.
		items := value.Array()
		arr := make([]any, 0, len(items))
		for _, item := range items {
			arr = append(arr, item.Value())
		}
		mapped[standardField] = arr
	}
	return mapped, nil
}

// ParseHeadMsgHeaders converts a head_msg object into HTTP request headers,
// stringifying every value with gconv.String. It returns nil for a nil or
// empty headMsg.
//
// Example head_msg:
//
//	{
//	  "Authorization": "Bearer xxx",
//	  "Content-Type": "application/json",
//	  "X-Api-App-Id": "<app-id>",
//	  "X-Api-Access-Key": "<access-key>"
//	}
func ParseHeadMsgHeaders(headMsg map[string]any) map[string]string {
	if len(headMsg) == 0 {
		return nil
	}
	out := make(map[string]string, len(headMsg))
	for k, v := range headMsg {
		out[k] = gconv.String(v)
	}
	return out
}

// GetModelBody extracts the request body from a stored model record. If v has
// a "body" key, that value converted with gconv.Map is returned. Otherwise v
// itself is returned. A nil v yields nil.
func GetModelBody(v map[string]any) map[string]any {
	if v == nil {
		return nil
	}
	if p, ok := v["body"]; ok {
		return gconv.Map(p)
	}
	return v
}

// BodyToQuery converts payload into url.Values, stringifying each value with
// gconv.String and skipping nil values. The returned error is always nil.
func BodyToQuery(payload map[string]any) (url.Values, error) {
	q := make(url.Values, len(payload))
	for k, v := range payload {
		if v == nil {
			continue
		}
		q.Set(k, gconv.String(v))
	}
	return q, nil
}

// PullTaskResult polls a task-status endpoint until the task identified by
// taskID reaches a terminal state or ctx is done.
//
// Recognised queryConfig keys:
//   - "url": endpoint (required); every {id} placeholder is replaced with taskID
//   - "method": HTTP method, case-insensitive, default GET
//   - "headers": request headers (any map gconv.MapStrStr understands)
//   - "body": extra body fields; "id" is always set to taskID. The body is
//     JSON-encoded and only sent for POST.
//   - "interval_seconds": delay between polls, default 2
//
// Transport errors, non-2xx responses and undecodable bodies are logged and
// retried after the interval. The decoded "status" field decides the outcome:
//   - "succeeded": the result is returned with a nil error.
//   - "failed" / "expired": the result is returned together with an error.
//   - "queued" / "running": polling continues.
//   - Anything else, including a missing status: the result is returned as-is.
//
// There is no attempt limit. Only a terminal status or ctx cancellation ends
// the loop, so callers should pass a ctx with a deadline.
func PullTaskResult(ctx context.Context, taskID string, queryConfig map[string]any) (map[string]any, error) {
	// Cap on how much of a non-2xx body is read into the log, and on how much
	// trailing data is drained after decoding a successful body.
	const maxBodySnippet = 4 << 10

	// 1. Read the polling config.
	rawURL := gconv.String(queryConfig["url"])
	if rawURL == "" {
		// Without a URL every request would fail and be retried forever.
		return nil, fmt.Errorf("查询配置缺少 url")
	}
	method := strings.ToUpper(strings.TrimSpace(gconv.String(queryConfig["method"])))
	if method == "" {
		method = http.MethodGet
	}
	headers := gconv.MapStrStr(queryConfig["headers"])
	interval := time.Duration(gconv.Int(queryConfig["interval_seconds"])) * time.Second
	if interval <= 0 {
		interval = 2 * time.Second
	}

	// 2. Both the {id} URL placeholder and the "id" body field carry the task ID.
	params := map[string]any{"id": taskID}
	finalURL := replaceURLParams(rawURL, params)

	// 3. The POST body does not change between polls: encode it once.
	var payload []byte
	if method == http.MethodPost {
		body := buildParams(gconv.Map(queryConfig["body"]), params)
		bs, err := json.Marshal(body)
		if err != nil {
			return nil, fmt.Errorf("序列化请求体失败: %w", err)
		}
		payload = bs
	}

	// One client for the whole poll so keep-alive connections are reused.
	client := &http.Client{Timeout: 30 * time.Second}

	// wait sleeps for one interval but returns ctx.Err() as soon as ctx is
	// done, so a cancelled caller is not held for a full interval.
	wait := func() error {
		timer := time.NewTimer(interval)
		defer timer.Stop()
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-timer.C:
			return nil
		}
	}

	// 4. Poll.
	for {
		if err := ctx.Err(); err != nil {
			return nil, err
		}

		var reqBody io.Reader
		if payload != nil {
			reqBody = bytes.NewReader(payload)
		}
		req, err := http.NewRequestWithContext(ctx, method, finalURL, reqBody)
		if err != nil {
			return nil, fmt.Errorf("创建请求失败: %w", err)
		}
		for k, v := range headers {
			req.Header.Set(k, v)
		}
		if reqBody != nil && req.Header.Get("Content-Type") == "" {
			req.Header.Set("Content-Type", "application/json")
		}

		resp, err := client.Do(req)
		if err != nil {
			g.Log().Warningf(ctx, "[PullTaskResult] 请求失败 taskID=%s err=%v", taskID, err)
			if werr := wait(); werr != nil {
				return nil, werr
			}
			continue
		}

		if resp.StatusCode < 200 || resp.StatusCode >= 300 {
			// Best effort: only a bounded prefix of the error body is logged.
			snippet, _ := io.ReadAll(io.LimitReader(resp.Body, maxBodySnippet))
			resp.Body.Close()
			g.Log().Warningf(ctx, "[PullTaskResult] 请求异常 taskID=%s status=%d body=%s", taskID, resp.StatusCode, string(snippet))
			if werr := wait(); werr != nil {
				return nil, werr
			}
			continue
		}

		var result map[string]any
		decodeErr := json.NewDecoder(resp.Body).Decode(&result)
		// Drain (bounded) whatever the decoder left so the connection can be reused.
		_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, maxBodySnippet))
		resp.Body.Close()
		if decodeErr != nil {
			g.Log().Warningf(ctx, "[PullTaskResult] 解析失败 taskID=%s err=%v", taskID, decodeErr)
			if werr := wait(); werr != nil {
				return nil, werr
			}
			continue
		}

		status := gconv.String(result["status"])
		g.Log().Infof(ctx, "[PullTaskResult] 轮询 taskID=%s status=%s", taskID, status)

		switch status {
		case "succeeded":
			return result, nil
		case "failed", "expired":
			return result, fmt.Errorf("任务失败: status=%s", status)
		case "queued", "running":
			if werr := wait(); werr != nil {
				return nil, werr
			}
		default:
			// No recognised status field: treat the response itself as the result.
			return result, nil
		}
	}
}

// buildParams returns a new map holding bodyCfg's entries overlaid with params
// (params wins on key clashes). Neither input is modified; values are copied
// shallowly.
func buildParams(bodyCfg map[string]any, params map[string]any) map[string]any {
	result := make(map[string]any, len(bodyCfg)+len(params))
	for k, v := range bodyCfg {
		result[k] = v
	}
	for k, v := range params {
		result[k] = v
	}
	return result
}

// replaceURLParams substitutes {key} placeholders in a URL.
func replaceURLParams(url string, params map[string]any) string { re := regexp.MustCompile(`\{([^}]+)\}`) return re.ReplaceAllStringFunc(url, func(s string) string { key := strings.Trim(s, "{}") if val, ok := params[key]; ok { return gconv.String(val) } return s }) } // replaceBodyParams 用 params 覆盖 body 中对应 key func replaceBodyParams(bodyCfg map[string]any, params map[string]any) map[string]any { result := make(map[string]any) for k, v := range bodyCfg { result[k] = v } for k, v := range params { result[k] = v } return result } // InjectCallbackURL 将回调地址注入到请求体中 func InjectCallbackURL(ctx context.Context, payload map[string]any, callbackURL string) map[string]any { if callbackURL == "" { return payload } payload[callbackURL] = GetCallbackURL(ctx, "/task/modelCallback") return payload }