package util import ( "bytes" "context" "encoding/json" "fmt" "io" "model-gateway/model/entity" "net/http" "net/url" "regexp" "strings" "time" "gitea.redpowerfuture.com/red-future/common/utils" "github.com/gogf/gf/v2/encoding/gjson" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/util/gconv" tgjson "github.com/tidwall/gjson" ) // ParseAndValidate 解析并校验结果 func ParseAndValidate(raw map[string]any, model *entity.AsynchModel) (map[string]any, error) { // 1) 解析 content 字符串为 rounds 数组 contentVal, ok := raw[model.ResponseBody] if !ok { return raw, fmt.Errorf("字段 %s 不存在", model.ResponseBody) } contentStr, ok := contentVal.(string) if !ok || strings.TrimSpace(contentStr) == "" { return raw, fmt.Errorf("字段 %s 为空或不是字符串", model.ResponseBody) } var arr []any if err := json.Unmarshal([]byte(contentStr), &arr); err != nil { return raw, fmt.Errorf("JSON解析失败: %w", err) } if len(arr) == 0 { return raw, fmt.Errorf("解析后数组为空") } // 2) 校验必填字段 if len(model.RequiredFields) > 0 { for i, r := range arr { round, ok := r.(map[string]any) if !ok { continue } for _, field := range model.RequiredFields { if gjson.New(round).Get(field).IsNil() { return raw, fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, field) } } } } return map[string]any{"total_rounds": len(arr), "rounds": arr}, nil } // ParseStructResult 解析结构结果 func ParseStructResult(raw map[string]any, responseBody string) map[string]any { contentVal := raw[responseBody] // 是字符串,尝试解析 contentStr := gconv.String(contentVal) if contentStr == "" || contentStr == "0" { return map[string]any{ "total_rounds": 1, "rounds": []map[string]any{{responseBody: raw}}, } } // 尝试解析为数组 var arr []any if err := json.Unmarshal([]byte(contentStr), &arr); err == nil && len(arr) > 0 { return map[string]any{ "total_rounds": 1, "rounds": []map[string]any{{responseBody: arr}}, } } // 尝试解析为单个对象 var parsed any if err := json.Unmarshal([]byte(contentStr), &parsed); err == nil { return map[string]any{ "total_rounds": 1, "rounds": []map[string]any{{responseBody: parsed}}, } } // 兜底:原始字符串作为内容 return map[string]any{ "total_rounds": 1, "rounds": []map[string]any{{responseBody: contentStr}}, } } // ValidatePromptResult 校验模型返回结果的 JSON 结构完整性 // raw 必须包含 "rounds" 字段,格式为 []map[string]any func ValidatePromptResult(raw map[string]any, model *entity.AsynchModel) error { // 1) 获取 rounds roundsRaw, ok := raw["rounds"] if !ok { return fmt.Errorf("缺少 rounds 字段") } rounds, ok := roundsRaw.([]any) if !ok { return fmt.Errorf("rounds 不是数组") } if len(rounds) == 0 { return fmt.Errorf("rounds 数组为空") } // 2) 没有配置必填字段,跳过 if len(model.RequiredFields) == 0 { return nil } // 3) 逐条校验 for i, r := range rounds { round, ok := r.(map[string]any) if !ok { continue } for _, field := range model.RequiredFields { if gjson.New(round).Get(field).IsNil() { return fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, field) } } } return nil } // validateRequiredFields 校验单个 round 对象的必选字段 func validateRequiredFields(round map[string]any, requiredFields []string, prefix string) error { for _, field := range requiredFields { if gjson.New(round).Get(field).IsNil() { return fmt.Errorf("%s 缺少必填字段: %s", prefix, field) } } return nil } // ParseHeadMsgHeaders 从 head_msg JSON 中提取请求头 // head_msg 格式示例: // // { // "Authorization": "Bearer xxx", // "Content-Type": "application/json", // "X-Api-App-Id": "5147401364", // "X-Api-Access-Key": "VCqRX7..." // } func ParseHeadMsgHeaders(headMsg map[string]any) map[string]string { if len(headMsg) == 0 { return nil } out := make(map[string]string, len(headMsg)) for k, v := range headMsg { out[k] = gconv.String(v) } return out } // MapResponsePayload 映射模型响应为标准格式 func MapResponsePayload(mapping map[string]any, result map[string]any) (map[string]any, error) { if len(mapping) == 0 { return result, nil } // 把 result 转成 JSON 字符串,tidwall/gjson 需要字符串输入 resultBytes, _ := json.Marshal(result) resultStr := string(resultBytes) mapped := make(map[string]any) for standardField, modelPath := range mapping { path := gconv.String(modelPath) if path == "" { continue } value := tgjson.Get(resultStr, path) if !value.Exists() { continue } // 如果是数组路径(含 #),取 Array;否则取单值 if strings.Contains(path, "#") { var arr []any for _, v := range value.Array() { arr = append(arr, v.Value()) } mapped[standardField] = arr } else { mapped[standardField] = value.Value() } } return mapped, nil } // GetModelBody 获取数据库中保存的模型信息 func GetModelBody(v map[string]any) map[string]any { if v == nil { return nil } if p, ok := v["body"]; ok { return gconv.Map(p) } return v } // BodyToQuery 将 body 转为 url.Values func BodyToQuery(payload map[string]any) (url.Values, error) { q := url.Values{} for k, v := range payload { if v == nil { continue } q.Set(k, gconv.String(v)) } return q, nil } // PullTaskResult 轮询查询异步任务结果直到完成 func PullTaskResult(ctx context.Context, body map[string]any, queryConfig map[string]any, headMsg map[string]any) (map[string]any, error) { // 1) 解析配置 // 1.1 提取 taskID taskIDPath := gconv.String(queryConfig["task_id"]) taskID := gconv.String(gjson.New(body).Get(taskIDPath).Val()) if taskID == "" { return nil, fmt.Errorf("无法从路径 %s 提取 taskID", taskIDPath) } g.Log().Infof(ctx, "[PullTaskResult] taskID=%s", taskID) // 1.2 请求地址,替换 {id} queryUrl := gconv.String(queryConfig["url"]) queryUrl = replaceURLParams(queryUrl, map[string]any{"id": taskID}) // 1.3 请求方式 method := gconv.String(queryConfig["method"]) if method == "" { method = "GET" } // 1.4 状态判断配置 statusPath := gconv.String(queryConfig["status_path"]) statusValues, _ := queryConfig["status_values"].(map[string]any) if statusPath == "" { statusPath = "status" } // 1.5 轮询间隔 interval := gconv.Int(queryConfig["interval_seconds"]) if interval <= 0 { interval = 2 } // 1.6 请求体 reqBodyMap := map[string]any{"task_id": taskID} // 2) 轮询请求 for { select { case <-ctx.Done(): return nil, ctx.Err() default: } var reqBody io.Reader if method == "POST" { bs, _ := json.Marshal(reqBodyMap) reqBody = bytes.NewReader(bs) } req, err := http.NewRequestWithContext(ctx, method, queryUrl, reqBody) if err != nil { return nil, fmt.Errorf("创建请求失败: %w", err) } // 统一用 headMsg 注入请求头 for hk, hv := range ParseHeadMsgHeaders(headMsg) { req.Header.Set(hk, hv) } client := &http.Client{Timeout: 30 * time.Second} resp, err := client.Do(req) if err != nil { g.Log().Warningf(ctx, "[PullTaskResult] 请求失败 taskID=%s err=%v", taskID, err) time.Sleep(time.Duration(interval) * time.Second) continue } raw, _ := io.ReadAll(resp.Body) _ = resp.Body.Close() g.Log().Infof(ctx, "[PullTaskResult] taskID=%s statusCode=%d body=%s", taskID, resp.StatusCode, string(raw)) if resp.StatusCode < 200 || resp.StatusCode >= 300 { time.Sleep(time.Duration(interval) * time.Second) continue } var result map[string]any _ = json.Unmarshal(raw, &result) statusVal := gjson.New(result).Get(statusPath).Val() statusStr := gconv.String(statusVal) g.Log().Infof(ctx, "[PullTaskResult] 状态 taskID=%s status=%v", taskID, statusVal) if matchStatus(statusStr, statusValues["succeeded"]) { g.Log().Infof(ctx, "[PullTaskResult] 任务成功 taskID=%s", taskID) return result, nil } if matchStatus(statusStr, statusValues["failed"]) { g.Log().Errorf(ctx, "[PullTaskResult] 任务失败 taskID=%s", taskID) return result, fmt.Errorf("任务失败") } time.Sleep(time.Duration(interval) * time.Second) } } func matchStatus(actual string, expected any) bool { expectedStr := gconv.String(expected) if actual == expectedStr { return true } switch v := expected.(type) { case []any: for _, item := range v { if actual == gconv.String(item) { return true } } } return false } // replaceURLParams 替换 URL 中的 {key} func replaceURLParams(url string, params map[string]any) string { re := regexp.MustCompile(`\{([^}]+)}`) return re.ReplaceAllStringFunc(url, func(s string) string { key := strings.Trim(s, "{}") if val, ok := params[key]; ok { return gconv.String(val) } return s }) } // InjectCallbackURL 将回调地址注入到请求体中 func InjectCallbackURL(ctx context.Context, payload map[string]any, callbackURL string) map[string]any { if callbackURL == "" { return payload } payload[callbackURL] = utils.GetCallbackURL(ctx, "/task/modelCallback") return payload }