package service

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"sort"
	"strconv"
	"strings"
	"time"
	"unicode/utf8"

	"model-asynch/model/entity"

	"github.com/gogf/gf/v2/container/gvar"
	"github.com/gogf/gf/v2/frame/g"
	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"
)

// parseHeadMsgHeaders parses a header-binding string into a header map.
// Multiple bindings are separated by commas, for example:
//   - X-API-Key:qwen3-tts-key,operation:true,count:123
//   - X-API-Key:"qwen3-tts-key",operation:"true"
//
// Each binding is Name:Value (preferred) or Name=Value (legacy). A binding is
// split at whichever of ':' or '=' appears first, so values may themselves
// contain ':' or '='.
//
// HTTP header values are always strings; this only controls how the value is
// written in config. Surrounding double quotes are stripped before injection,
// and commas inside double quotes do not start a new binding, so quoting a
// value avoids ambiguity with special characters. Bindings with an empty name
// or value are skipped. Returns nil when no usable binding is found.
func parseHeadMsgHeaders(headMsg string) map[string]string {
	headMsg = strings.TrimSpace(headMsg)
	if headMsg == "" {
		return nil
	}

	out := map[string]string{}
	for _, part := range splitHeadMsgBindings(headMsg) {
		part = strings.TrimSpace(part)
		sep := strings.IndexAny(part, ":=")
		if sep < 0 {
			continue
		}
		k := strings.TrimSpace(part[:sep])
		v := strings.Trim(strings.TrimSpace(part[sep+1:]), "\"")
		if k != "" && v != "" {
			out[k] = v
		}
	}
	if len(out) == 0 {
		return nil
	}
	return out
}

// splitHeadMsgBindings splits s on commas that are not inside double quotes.
// If the quotes are unbalanced, it falls back to splitting on every comma.
func splitHeadMsgBindings(s string) []string {
	var parts []string
	inQuotes := false
	start := 0
	for i := 0; i < len(s); i++ {
		switch s[i] {
		case '"':
			inQuotes = !inQuotes
		case ',':
			if !inQuotes {
				parts = append(parts, s[start:i])
				start = i + 1
			}
		}
	}
	if inQuotes {
		return strings.Split(s, ",")
	}
	return append(parts, s[start:])
}

// decodeJSONObject decodes a single JSON object into a map. Numbers are kept
// as json.Number, so integers beyond float64 precision survive a round trip
// and re-encode with their original literal form. JSON null yields a nil map.
// Trailing content after the value is rejected, as json.Unmarshal would do.
func decodeJSONObject(data []byte) (map[string]any, error) {
	dec := json.NewDecoder(bytes.NewReader(data))
	dec.UseNumber()
	var m map[string]any
	if err := dec.Decode(&m); err != nil {
		return nil, err
	}
	if err := dec.Decode(new(json.RawMessage)); !errors.Is(err, io.EOF) {
		return nil, errors.New("unexpected data after top-level JSON value")
	}
	return m, nil
}

// payloadToQuery flattens a payload that JSON-encodes to an object into URL
// query values:
//   - strings are used as-is;
//   - numbers keep their JSON literal text (1000000 stays "1000000");
//   - booleans become "true"/"false";
//   - arrays and objects are JSON-encoded;
//   - null fields are skipped.
//
// A nil payload yields empty values. A payload that does not encode to a JSON
// object is an error.
func payloadToQuery(payload any) (url.Values, error) {
	q := url.Values{}
	if payload == nil {
		return q, nil
	}

	raw, err := json.Marshal(payload)
	if err != nil {
		return nil, err
	}
	fields, err := decodeJSONObject(raw)
	if err != nil {
		return nil, err
	}

	for k, v := range fields {
		switch vv := v.(type) {
		case nil:
			continue
		case string:
			q.Set(k, vv)
		case json.Number:
			q.Set(k, vv.String())
		case bool:
			q.Set(k, strconv.FormatBool(vv))
		default:
			// Nested arrays/objects are passed as a JSON string.
			bs, err := json.Marshal(vv)
			if err != nil {
				return nil, err
			}
			q.Set(k, string(bs))
		}
	}
	return q, nil
}

// InvokeModel calls the model service described by m and returns the response
// body, response-mapped when m.ResponseMapping is configured.
//
// Request flow:
//   - payload is first merged with the fixed parameters from m.RequestMapping
//     (see mapRequestPayload);
//   - for GET, the merged payload is encoded into the query string;
//   - for any other method, it is sent as a JSON body.
//
// m.HttpMethod defaults to POST. PUT, PATCH and DELETE are honored, and any
// other non-GET value is sent as POST.
//
// Headers are applied from m.HeadMsg first and then from modelKey, both in the
// syntax accepted by parseHeadMsgHeaders. modelKey can therefore override or
// extend the static headers, for example to send a per-request X-API-Key.
//
// The timeout is m.TimeoutSeconds, or 60s when that is not positive. A non-2xx
// status returns an error that includes at most 2000 bytes of the body. If
// response mapping fails, a warning is logged and the unmapped body is
// returned.
func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelKey string) ([]byte, error) {
	if m == nil || m.BaseURL == "" {
		return nil, fmt.Errorf("模型配置不完整")
	}

	mappedPayload, err := mapRequestPayload(m.RequestMapping, payload)
	if err != nil {
		return nil, fmt.Errorf("请求参数映射失败: %w", err)
	}

	timeout := time.Duration(m.TimeoutSeconds) * time.Second
	if timeout <= 0 {
		timeout = 60 * time.Second
	}
	client := &http.Client{Timeout: timeout}

	method := strings.ToUpper(strings.TrimSpace(m.HttpMethod))
	endpoint := strings.TrimRight(m.BaseURL, "/")
	req, err := newModelHTTPRequest(ctx, method, endpoint, mappedPayload)
	if err != nil {
		return nil, err
	}

	// Static head_msg headers first (e.g. a fixed API key for a shared model).
	for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) {
		req.Header.Set(hk, hv)
	}
	// Dynamic per-request modelKey headers last, so they win on conflicts.
	for hk, hv := range parseHeadMsgHeaders(modelKey) {
		req.Header.Set(hk, hv)
	}
	// The body is always JSON, so this deliberately overrides any configured
	// Content-Type.
	if req.Method != http.MethodGet {
		req.Header.Set("Content-Type", "application/json")
	}

	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, truncateUTF8(body, 2000))
	}

	mappedResponse, err := mapResponsePayload(m.ResponseMapping, body)
	if err != nil {
		// Response mapping is best-effort: do not fail the call, return the raw body.
		g.Log().Warningf(ctx, "响应参数映射失败: %v,返回原始数据", err)
		return body, nil
	}
	return mappedResponse, nil
}

// newModelHTTPRequest builds the outgoing request for InvokeModel.
//
// For GET, payload is encoded into the query string, appended with '&' when
// endpoint already has a query. For every other method, payload is
// JSON-encoded as the body. PUT, PATCH and DELETE are used as given; anything
// else, including an empty method, becomes POST.
//
// Errors from encoding or from http.NewRequestWithContext are returned to the
// caller rather than being lost.
func newModelHTTPRequest(ctx context.Context, method, endpoint string, payload any) (*http.Request, error) {
	if method == http.MethodGet {
		q, err := payloadToQuery(payload)
		if err != nil {
			return nil, err
		}
		if len(q) > 0 {
			sep := "?"
			if strings.Contains(endpoint, "?") {
				sep = "&"
			}
			endpoint += sep + q.Encode()
		}
		return http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	}

	bodyBytes, err := json.Marshal(payload)
	if err != nil {
		return nil, err
	}

	bodyMethod := http.MethodPost
	switch method {
	case http.MethodPut, http.MethodPatch, http.MethodDelete:
		bodyMethod = method
	}
	return http.NewRequestWithContext(ctx, bodyMethod, endpoint, bytes.NewReader(bodyBytes))
}

// truncateUTF8 returns at most limit bytes of b as a string. When the cut falls
// inside a multi-byte UTF-8 character, it backs up to that character's start
// so the character is not split.
func truncateUTF8(b []byte, limit int) string {
	if limit <= 0 {
		return ""
	}
	if len(b) <= limit {
		return string(b)
	}
	cut := limit
	// A valid rune has at most UTFMax-1 continuation bytes. Past that the data
	// is not valid UTF-8 and any cut point is as good as another.
	for i := 1; i < utf8.UTFMax && cut > 0 && !utf8.RuneStart(b[cut]); i++ {
		cut--
	}
	return string(b[:cut])
}

// ============================================
// Mapping helpers
// ============================================

// mapRequestPayload merges the fixed parameters from the model's request
// mapping into payload.
//
// Without a mapping, payload is returned unchanged. Otherwise payload is
// normalized to a JSON-object map:
//   - a map[string]any is shallow-copied, so the caller's map is never
//     modified;
//   - a bare []map[string]any is wrapped as {"messages": [...]};
//   - anything else goes through a JSON round trip, with numbers kept as
//     json.Number;
//   - a nil or JSON-null payload becomes an empty map.
//
// A mapping value is written only when the key is absent or empty (see
// isEmptyValue); non-empty caller values always win.
func mapRequestPayload(mappingAny any, payload any) (any, error) {
	// Mapping values are `any`, so bools, numbers and objects are supported.
	mapping, err := parseRequestMapping(mappingAny)
	if err != nil {
		return nil, err
	}
	if len(mapping) == 0 {
		return payload, nil
	}

	var payloadMap map[string]any
	switch v := payload.(type) {
	case map[string]any:
		payloadMap = make(map[string]any, len(v)+len(mapping))
		for k, val := range v {
			payloadMap[k] = val
		}
	case []map[string]any:
		// A bare messages array is wrapped into the standard shape.
		payloadMap = map[string]any{
			"messages": v,
		}
	default:
		var jsonBytes []byte
		if jsonBytes, err = json.Marshal(payload); err != nil {
			return nil, fmt.Errorf("序列化payload失败: %w", err)
		}
		if payloadMap, err = decodeJSONObject(jsonBytes); err != nil {
			return nil, fmt.Errorf("反序列化payload失败: %w", err)
		}
	}
	if payloadMap == nil {
		// Assigning into a nil map would panic below.
		payloadMap = make(map[string]any, len(mapping))
	}

	for key, value := range mapping {
		if existing, ok := payloadMap[key]; !ok || isEmptyValue(existing) {
			payloadMap[key] = value
		}
	}
	return payloadMap, nil
}

// mapResponsePayload builds a new JSON object from the model response.
// mapping maps each standard output field (an sjson path) to a gjson path in
// the response; fields whose source path does not exist are omitted.
//
// Without a mapping, the response is returned unchanged. A response that is
// not valid JSON (e.g. binary audio) yields an error, so the caller can fall
// back to the raw bytes instead of receiving an empty "{}".
func mapResponsePayload(mappingAny any, responseBytes []byte) ([]byte, error) {
	mapping, err := parseResponseMapping(mappingAny)
	if err != nil {
		return nil, err
	}
	if len(mapping) == 0 {
		return responseBytes, nil
	}

	responseStr := string(responseBytes)
	if !gjson.Valid(responseStr) {
		return nil, fmt.Errorf("响应不是合法的JSON,无法按映射提取字段")
	}

	// Apply fields in sorted order so overlapping target paths (e.g. "data"
	// and "data.text") give the same result on every call.
	fields := make([]string, 0, len(mapping))
	for field := range mapping {
		fields = append(fields, field)
	}
	sort.Strings(fields)

	resultStr := `{}`
	for _, standardField := range fields {
		modelPath := mapping[standardField]
		value := gjson.Get(responseStr, modelPath)
		if !value.Exists() {
			continue
		}
		resultStr, err = sjson.SetRaw(resultStr, standardField, value.Raw)
		if err != nil {
			return nil, fmt.Errorf("提取字段 %s <- %s 失败: %w", standardField, modelPath, err)
		}
	}
	return []byte(resultStr), nil
}

// isBlankMappingJSON reports whether a textual mapping config means
// "no mapping": empty or whitespace-only, "{}", or "null".
func isBlankMappingJSON(s string) bool {
	s = strings.TrimSpace(s)
	return s == "" || s == "{}" || s == "null"
}

// parseRequestMapping normalizes the request-mapping config into a map of
// fixed request parameters. Values keep their JSON types (bool, number,
// string, object, ...); numbers parsed from JSON text are json.Number.
//
// Accepted inputs:
//   - nil;
//   - *gvar.Var, read via Map() first and then as a JSON string;
//   - map[string]any, returned as-is;
//   - a JSON string or JSON bytes;
//   - anything that JSON-encodes to an object.
//
// nil, blank, "{}" and "null" yield a nil map.
func parseRequestMapping(mappingAny any) (map[string]any, error) {
	switch v := mappingAny.(type) {
	case nil:
		return nil, nil
	case *gvar.Var:
		if v == nil || v.IsNil() || v.IsEmpty() {
			return nil, nil
		}
		if m := v.Map(); m != nil {
			result := make(map[string]any, len(m))
			for k, val := range m {
				result[k] = val
			}
			return result, nil
		}
		s := v.String()
		if isBlankMappingJSON(s) {
			return nil, nil
		}
		result, err := decodeJSONObject([]byte(s))
		if err != nil {
			return nil, fmt.Errorf("解析请求映射字符串失败: %w", err)
		}
		return result, nil
	case map[string]any:
		return v, nil
	case string:
		if isBlankMappingJSON(v) {
			return nil, nil
		}
		result, err := decodeJSONObject([]byte(v))
		if err != nil {
			return nil, fmt.Errorf("解析请求映射字符串失败: %w", err)
		}
		return result, nil
	case []byte:
		if isBlankMappingJSON(string(v)) {
			return nil, nil
		}
		result, err := decodeJSONObject(v)
		if err != nil {
			return nil, fmt.Errorf("解析请求映射字节失败: %w", err)
		}
		return result, nil
	default:
		jsonBytes, err := json.Marshal(mappingAny)
		if err != nil {
			return nil, fmt.Errorf("序列化映射配置失败: %w", err)
		}
		result, err := decodeJSONObject(jsonBytes)
		if err != nil {
			return nil, fmt.Errorf("解析映射配置失败: %w", err)
		}
		return result, nil
	}
}

// parseResponseMapping normalizes the response-mapping config into a map from
// standard field path to a JSON path in the model response.
//
// Accepted inputs are the same as parseRequestMapping. For map inputs
// (*gvar.Var via Map(), or map[string]any), non-string values are skipped;
// for JSON text, non-string values are a decode error. nil, blank, "{}" and
// "null" yield a nil map.
func parseResponseMapping(mappingAny any) (map[string]string, error) {
	// stringValues keeps only the string-valued entries of m.
	stringValues := func(m map[string]any) map[string]string {
		out := make(map[string]string, len(m))
		for k, val := range m {
			if s, ok := val.(string); ok {
				out[k] = s
			}
		}
		return out
	}

	var mapping map[string]string
	switch v := mappingAny.(type) {
	case nil:
		return nil, nil
	case *gvar.Var:
		if v == nil || v.IsNil() || v.IsEmpty() {
			return nil, nil
		}
		if m := v.Map(); m != nil {
			return stringValues(m), nil
		}
		s := v.String()
		if isBlankMappingJSON(s) {
			return nil, nil
		}
		if err := json.Unmarshal([]byte(s), &mapping); err != nil {
			return nil, fmt.Errorf("解析响应映射字符串失败: %w", err)
		}
	case string:
		if isBlankMappingJSON(v) {
			return nil, nil
		}
		if err := json.Unmarshal([]byte(v), &mapping); err != nil {
			return nil, fmt.Errorf("解析响应映射字符串失败: %w", err)
		}
	case map[string]any:
		// e.g. a JSONB column already decoded by the database driver.
		return stringValues(v), nil
	case []byte:
		if isBlankMappingJSON(string(v)) {
			return nil, nil
		}
		if err := json.Unmarshal(v, &mapping); err != nil {
			return nil, fmt.Errorf("解析响应映射字节失败: %w", err)
		}
	default:
		jsonBytes, err := json.Marshal(mappingAny)
		if err != nil {
			return nil, fmt.Errorf("序列化响应映射配置失败: %w", err)
		}
		if err := json.Unmarshal(jsonBytes, &mapping); err != nil {
			return nil, fmt.Errorf("解析响应映射配置失败: %w", err)
		}
	}
	return mapping, nil
}

// isEmptyValue reports whether v counts as "not provided" when fixed request
// parameters are merged: nil, "", an empty []any, or an empty map[string]any.
// Every other value, including false and 0, counts as provided.
func isEmptyValue(v any) bool {
	switch val := v.(type) {
	case nil:
		return true
	case string:
		return val == ""
	case []any:
		return len(val) == 0
	case map[string]any:
		return len(val) == 0
	default:
		return false
	}
}