package util import ( "bytes" "context" "encoding/json" "fmt" "io" "model-gateway/model/entity" "net/http" "net/url" "regexp" "strings" "time" "gitea.com/red-future/common/utils" "github.com/gogf/gf/v2/encoding/gjson" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/util/gconv" tgjson "github.com/tidwall/gjson" ) // ValidatePromptResult 校验模型返回结果的 JSON 结构完整性 // 校验逻辑:只校验 requestMapping 中默认值为空的必填字段 func ValidatePromptResult(raw map[string]any, model *entity.AsynchModel) error { // 1) 获取校验配置,并取值 requestMapping := model.RequestMapping contentStr, ok := raw[model.ResponseBody].(string) if !ok || contentStr == "" { return fmt.Errorf("%s 字段为空或不是字符串", model.ResponseBody) } // 2) 解析 content 为 JSON 数组 var rounds []map[string]any if err := gjson.DecodeTo(contentStr, &rounds); err != nil { return fmt.Errorf("解析 content JSON 数组失败: %w", err) } if len(rounds) == 0 { return fmt.Errorf("content 数组为空") } // 3) 逐条校验:只检查默认值为空的必填字段是否存在 for i, round := range rounds { for path, defaultValue := range requestMapping { if !g.IsEmpty(defaultValue) { continue } if gjson.New(round).Get(path).IsNil() { return fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, path) } } } return nil } // ParseHeadMsgHeaders 从 head_msg JSON 中提取请求头 // head_msg 格式示例: // // { // "Authorization": "Bearer xxx", // "Content-Type": "application/json", // "X-Api-App-Id": "5147401364", // "X-Api-Access-Key": "VCqRX7..." // } func ParseHeadMsgHeaders(headMsg map[string]any) map[string]string { if len(headMsg) == 0 { return nil } out := make(map[string]string, len(headMsg)) for k, v := range headMsg { out[k] = gconv.String(v) } return out } // MapResponsePayload 映射模型响应为标准格式 func MapResponsePayload(mapping map[string]any, result map[string]any) (map[string]any, error) { if len(mapping) == 0 { return result, nil } // 把 result 转成 JSON 字符串,tidwall/gjson 需要字符串输入 resultBytes, _ := json.Marshal(result) resultStr := string(resultBytes) mapped := make(map[string]any) for standardField, modelPath := range mapping { path := gconv.String(modelPath) if path == "" { continue } value := tgjson.Get(resultStr, path) if !value.Exists() { continue } // 如果是数组路径(含 #),取 Array;否则取单值 if strings.Contains(path, "#") { var arr []any for _, v := range value.Array() { arr = append(arr, v.Value()) } mapped[standardField] = arr } else { mapped[standardField] = value.Value() } } return mapped, nil } // GetModelBody 获取数据库中保存的模型信息 func GetModelBody(v map[string]any) map[string]any { if v == nil { return nil } if p, ok := v["body"]; ok { return gconv.Map(p) } return v } // BodyToQuery 将 body 转为 url.Values func BodyToQuery(payload map[string]any) (url.Values, error) { q := url.Values{} for k, v := range payload { if v == nil { continue } q.Set(k, gconv.String(v)) } return q, nil } // PullTaskResult 轮询查询异步任务结果直到完成 func PullTaskResult(ctx context.Context, body map[string]any, queryConfig map[string]any, headMsg map[string]any) (map[string]any, error) { // 1) 解析配置 // 1.1 提取 taskID taskIDPath := gconv.String(queryConfig["task_id"]) taskID := gconv.String(gjson.New(body).Get(taskIDPath).Val()) if taskID == "" { return nil, fmt.Errorf("无法从路径 %s 提取 taskID", taskIDPath) } g.Log().Infof(ctx, "[PullTaskResult] taskID=%s", taskID) // 1.2 请求地址,替换 {id} queryUrl := gconv.String(queryConfig["url"]) queryUrl = replaceURLParams(queryUrl, map[string]any{"id": taskID}) // 1.3 请求方式 method := gconv.String(queryConfig["method"]) if method == "" { method = "GET" } // 1.4 状态判断配置 statusPath := gconv.String(queryConfig["status_path"]) statusValues, _ := queryConfig["status_values"].(map[string]any) if statusPath == "" { statusPath = "status" } // 1.5 轮询间隔 interval := gconv.Int(queryConfig["interval_seconds"]) if interval <= 0 { interval = 2 } // 1.6 请求体 reqBodyMap := map[string]any{"task_id": taskID} // 2) 轮询请求 for { select { case <-ctx.Done(): return nil, ctx.Err() default: } var reqBody io.Reader if method == "POST" { bs, _ := json.Marshal(reqBodyMap) reqBody = bytes.NewReader(bs) } req, err := http.NewRequestWithContext(ctx, method, queryUrl, reqBody) if err != nil { return nil, fmt.Errorf("创建请求失败: %w", err) } // 统一用 headMsg 注入请求头 for hk, hv := range ParseHeadMsgHeaders(headMsg) { req.Header.Set(hk, hv) } client := &http.Client{Timeout: 30 * time.Second} resp, err := client.Do(req) if err != nil { g.Log().Warningf(ctx, "[PullTaskResult] 请求失败 taskID=%s err=%v", taskID, err) time.Sleep(time.Duration(interval) * time.Second) continue } raw, _ := io.ReadAll(resp.Body) _ = resp.Body.Close() g.Log().Infof(ctx, "[PullTaskResult] taskID=%s statusCode=%d body=%s", taskID, resp.StatusCode, string(raw)) if resp.StatusCode < 200 || resp.StatusCode >= 300 { time.Sleep(time.Duration(interval) * time.Second) continue } var result map[string]any _ = json.Unmarshal(raw, &result) statusVal := gjson.New(result).Get(statusPath).Val() statusStr := gconv.String(statusVal) g.Log().Infof(ctx, "[PullTaskResult] 状态 taskID=%s status=%v", taskID, statusVal) if matchStatus(statusStr, statusValues["succeeded"]) { g.Log().Infof(ctx, "[PullTaskResult] 任务成功 taskID=%s", taskID) return result, nil } if matchStatus(statusStr, statusValues["failed"]) { g.Log().Errorf(ctx, "[PullTaskResult] 任务失败 taskID=%s", taskID) return result, fmt.Errorf("任务失败") } time.Sleep(time.Duration(interval) * time.Second) } } func matchStatus(actual string, expected any) bool { expectedStr := gconv.String(expected) if actual == expectedStr { return true } switch v := expected.(type) { case []any: for _, item := range v { if actual == gconv.String(item) { return true } } } return false } // replaceURLParams 替换 URL 中的 {key} func replaceURLParams(url string, params map[string]any) string { re := regexp.MustCompile(`\{([^}]+)}`) return re.ReplaceAllStringFunc(url, func(s string) string { key := strings.Trim(s, "{}") if val, ok := params[key]; ok { return gconv.String(val) } return s }) } // InjectCallbackURL 将回调地址注入到请求体中 func InjectCallbackURL(ctx context.Context, payload map[string]any, callbackURL string) map[string]any { if callbackURL == "" { return payload } payload[callbackURL] = utils.GetCallbackURL(ctx, "/task/modelCallback") return payload }