package service

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"time"
	"unicode/utf8"

	"model-asynch/model/entity"
)

// parseHeadMsgHeaders parses a header spec made of comma-separated
// "Name:Value" pairs ("Name=Value" is also accepted for compatibility), e.g.
//
//	X-API-Key:qwen3-tts-key,operation:true,count:123
//	X-API-Key:"qwen3-tts-key",operation:"true"
//
// HTTP header values are always strings; the value text is injected as-is.
// All leading and trailing double quotes are stripped from a value, which lets
// the config distinguish "string-looking" values and protect special
// characters: a comma inside a double-quoted value does not split the pair.
// The first ':' or '=' of each pair is the separator, so the value itself may
// contain either character (e.g. X-Token=abc:123 -> X-Token = "abc:123").
// Pairs with no separator, an empty name or an empty value are ignored.
// Returns nil when no header is found.
func parseHeadMsgHeaders(headMsg string) map[string]string {
	headMsg = strings.TrimSpace(headMsg)
	if headMsg == "" {
		return nil
	}
	out := make(map[string]string)
	for _, part := range splitHeadMsgParts(headMsg) {
		part = strings.TrimSpace(part)
		sep := strings.IndexAny(part, ":=")
		if sep < 0 {
			continue
		}
		k := strings.TrimSpace(part[:sep])
		v := strings.Trim(strings.TrimSpace(part[sep+1:]), "\"")
		if k != "" && v != "" {
			out[k] = v
		}
	}
	if len(out) == 0 {
		return nil
	}
	return out
}

// splitHeadMsgParts splits s on commas that are not inside a double-quoted
// span. When the quotes in s are unbalanced the input is malformed, and it
// falls back to a plain comma split (the historical behavior).
// Byte-wise scanning is UTF-8 safe because '"' and ',' are ASCII and never
// occur inside a multi-byte sequence.
func splitHeadMsgParts(s string) []string {
	if strings.Count(s, "\"")%2 != 0 {
		return strings.Split(s, ",")
	}
	var parts []string
	inQuotes := false
	start := 0
	for i := 0; i < len(s); i++ {
		switch s[i] {
		case '"':
			inQuotes = !inQuotes
		case ',':
			if !inQuotes {
				parts = append(parts, s[start:i])
				start = i + 1
			}
		}
	}
	return append(parts, s[start:])
}

// payloadToQuery flattens payload into URL query parameters for GET requests.
//
// payload is JSON-encoded and decoded back as a top-level JSON object; each
// member becomes one parameter:
//   - null members are skipped;
//   - strings are used verbatim;
//   - numbers keep their JSON literal text (decoded with UseNumber), so
//     1000000 stays "1000000" rather than float64 %v formatting ("1e+06"),
//     and large integer IDs keep full precision;
//   - booleans become "true" / "false";
//   - arrays and objects are re-encoded as a JSON string.
//
// A nil payload yields empty values. A payload that does not encode to a JSON
// object (e.g. a slice or a plain string) returns the decoding error.
func payloadToQuery(payload any) (url.Values, error) {
	if payload == nil {
		return url.Values{}, nil
	}
	raw, err := json.Marshal(payload)
	if err != nil {
		return nil, err
	}
	dec := json.NewDecoder(bytes.NewReader(raw))
	dec.UseNumber()
	var fields map[string]any
	if err := dec.Decode(&fields); err != nil {
		return nil, err
	}

	q := url.Values{}
	for k, v := range fields {
		switch vv := v.(type) {
		case nil:
			continue
		case string:
			q.Set(k, vv)
		case json.Number:
			q.Set(k, vv.String())
		case bool:
			q.Set(k, strconv.FormatBool(vv))
		default:
			// Arrays/objects: re-encoding values that were just decoded
			// from valid JSON cannot fail, so the error is ignored.
			bs, _ := json.Marshal(vv)
			q.Set(k, string(bs))
		}
	}
	return q, nil
}

// InvokeModel calls the model service described by m and returns the raw
// response body (binary result) when the service answers with a 2xx status.
//
// Request construction:
//   - URL: m.BaseURL joined with m.Route by a single '/'; a blank Route means
//     BaseURL alone (trailing slashes removed).
//   - Method: m.HttpMethod (case-insensitive), defaulting to POST. GET carries
//     payload as query parameters (see payloadToQuery); every other method
//     carries payload as a JSON body with Content-Type: application/json.
//   - Timeout: m.TimeoutSeconds, or 60s when it is not positive.
//
// Headers are applied in increasing precedence: the static head_msg of the
// model config, then headers returned by forwardHeaders(ctx), then modelKey
// (same syntax as head_msg; used to override or add headers per call, e.g. a
// different X-API-Key for each request).
//
// A non-2xx response yields an error that includes the status code and at
// most 2000 bytes of the response body.
func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelKey string) ([]byte, error) {
	if m == nil || strings.TrimSpace(m.BaseURL) == "" {
		return nil, fmt.Errorf("模型配置不完整")
	}

	endpoint := strings.TrimRight(strings.TrimSpace(m.BaseURL), "/")
	if route := strings.TrimSpace(m.Route); route != "" {
		endpoint += "/" + strings.TrimLeft(route, "/")
	}

	timeout := time.Duration(m.TimeoutSeconds) * time.Second
	if timeout <= 0 {
		timeout = 60 * time.Second
	}
	// A nil Transport means http.DefaultTransport, so connections are still
	// pooled across calls even though the client itself is per call.
	client := &http.Client{Timeout: timeout}

	method := strings.ToUpper(strings.TrimSpace(m.HttpMethod))
	if method == "" {
		method = http.MethodPost
	}

	req, err := newModelHTTPRequest(ctx, method, endpoint, payload)
	if err != nil {
		return nil, err
	}

	// 1) Static headers from the model config's head_msg.
	for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) {
		req.Header.Set(hk, hv)
	}
	// 2) Forward required inbound headers (e.g. Authorization / X-User-Info).
	for k, v := range forwardHeaders(ctx) {
		if v != "" {
			req.Header.Set(k, v)
		}
	}
	// 3) Per-call modelKey last, so it overrides/extends the static head_msg.
	for hk, hv := range parseHeadMsgHeaders(modelKey) {
		req.Header.Set(hk, hv)
	}
	// Set after all configurable headers so a JSON body is always labelled as JSON.
	if method != http.MethodGet {
		req.Header.Set("Content-Type", "application/json")
	}

	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		// Return (a bounded prefix of) the error body to ease troubleshooting.
		const maxErrBody = 2000
		msg := string(body)
		if len(msg) > maxErrBody {
			cut := maxErrBody
			// Back up (at most UTFMax-1 bytes) to a rune boundary so a
			// multi-byte character, e.g. in a Chinese message, is not split.
			for i := 0; i < utf8.UTFMax-1 && cut > 0 && !utf8.RuneStart(msg[cut]); i++ {
				cut--
			}
			msg = msg[:cut]
		}
		return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
	}
	return body, nil
}

// newModelHTTPRequest builds the outgoing request for InvokeModel. A GET
// carries payload as query parameters appended to endpoint; any other method
// carries payload as a JSON body. Every error (payload encoding, invalid URL or
// method) is returned to the caller instead of yielding a nil request.
func newModelHTTPRequest(ctx context.Context, method, endpoint string, payload any) (*http.Request, error) {
	if method == http.MethodGet {
		q, err := payloadToQuery(payload)
		if err != nil {
			return nil, err
		}
		if len(q) > 0 {
			sep := "?"
			if strings.Contains(endpoint, "?") {
				sep = "&"
			}
			endpoint += sep + q.Encode()
		}
		return http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	}

	body, err := json.Marshal(payload)
	if err != nil {
		return nil, err
	}
	return http.NewRequestWithContext(ctx, method, endpoint, bytes.NewReader(body))
}