Files
data-engine/service/sync/api_client.go
2026-05-29 18:39:32 +08:00

253 lines
6.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package sync
import (
"bytes"
"context"
"crypto/rand"
"encoding/json"
"fmt"
"io"
"math/big"
"net/http"
"net/url"
"strings"
"time"
"github.com/sirupsen/logrus"
)
// ApiResult is what a single API call yields: the raw response payload and
// how long the round trip took. execute also returns it for HTTP status codes
// >= 400, alongside an error.
type ApiResult struct {
	Body       []byte // response body exactly as read from the server
	DurationMs int64  // elapsed wall-clock time of the request, in milliseconds
}
// ApiClient is a general-purpose HTTP client bound to one platform's
// configuration (endpoints, authentication, retry and timeout settings).
type ApiClient struct {
	config *PlatformConfig // platform settings consulted on every request
	client *http.Client    // reused HTTP client; NewApiClient sets its Timeout
}
// NewApiClient returns an ApiClient for config. The HTTP client's overall
// timeout is config.RequestTimeoutMs milliseconds when that value is
// positive, and 30 seconds otherwise.
func NewApiClient(config *PlatformConfig) *ApiClient {
	const defaultTimeout = 30 * time.Second

	httpTimeout := defaultTimeout
	if ms := config.RequestTimeoutMs; ms > 0 {
		httpTimeout = time.Duration(ms) * time.Millisecond
	}

	c := &ApiClient{config: config}
	c.client = &http.Client{Timeout: httpTimeout}
	return c
}
// Get sends a GET request to path with no parameters and no body, subject to
// the client's retry policy.
func (c *ApiClient) Get(ctx context.Context, path string) (*ApiResult, error) {
	var noBody interface{}
	return c.doRequest(ctx, http.MethodGet, path, noBody, false)
}
// PostJSON sends a POST request to path whose body is the JSON encoding of
// body (no body is sent when body is nil), subject to the retry policy.
func (c *ApiClient) PostJSON(ctx context.Context, path string, body interface{}) (*ApiResult, error) {
	const inQuery = false // the payload travels in the request body, not the URL
	return c.doRequest(ctx, http.MethodPost, path, body, inQuery)
}
// Request is the general-purpose entry point for any HTTP method. Where
// params are placed depends on the arguments:
//   - paramsInQuery == true: params are encoded into the query string;
//   - method "GET": params always go into the query string;
//   - otherwise: params are JSON-encoded as the request body.
//
// A nil params map means "no parameters". No body is sent, and the URL is not
// rewritten.
func (c *ApiClient) Request(ctx context.Context, method, path string, params map[string]interface{}, paramsInQuery bool) (*ApiResult, error) {
	// Convert a nil map to an untyped nil interface. A nil map stored in an
	// interface{} is itself non-nil, so execute would JSON-encode it and send
	// the literal body "null".
	var payload interface{}
	if params != nil {
		payload = params
	}
	inQuery := paramsInQuery || method == http.MethodGet
	return c.doRequest(ctx, method, path, payload, inQuery)
}
// doRequest calls execute with retries.
//
// Retry policy:
//   - Up to MaxRetries+1 attempts are made; MaxRetries defaults to 3 when it
//     is not positive.
//   - Between attempts it waits RetryDelayMs*(attempt number), a linearly
//     growing delay; RetryDelayMs defaults to 1s when it is not positive.
//   - Every failed attempt is retried, including HTTP error responses.
//
// The wait is abandoned as soon as ctx is done. The returned error then
// wraps ctx.Err(), so errors.Is(err, context.Canceled) and
// errors.Is(err, context.DeadlineExceeded) work.
//
// On final failure it returns the last attempt's result together with a
// wrapped error. That result is non-nil for HTTP status codes >= 400.
func (c *ApiClient) doRequest(ctx context.Context, method, path string, body interface{}, paramsInQuery bool) (result *ApiResult, err error) {
	maxRetries := c.config.MaxRetries
	if maxRetries <= 0 {
		maxRetries = 3
	}
	retryDelay := time.Duration(c.config.RetryDelayMs) * time.Millisecond
	if retryDelay <= 0 {
		retryDelay = 1 * time.Second
	}

	for attempt := 0; attempt <= maxRetries; attempt++ {
		result, err = c.execute(ctx, method, path, body, paramsInQuery)
		if err == nil {
			return result, nil
		}
		logrus.Warnf("请求失败 (attempt %d/%d): %v", attempt+1, maxRetries+1, err)
		if attempt == maxRetries {
			break
		}
		// Back off, but never outlive the caller's context. time.Sleep would
		// ignore cancellation, and once ctx is done every further attempt
		// fails immediately anyway.
		wait := time.NewTimer(retryDelay * time.Duration(attempt+1))
		select {
		case <-ctx.Done():
			wait.Stop()
			return result, fmt.Errorf("请求在第 %d 次尝试后被取消 (最后错误: %v): %w", attempt+1, err, ctx.Err())
		case <-wait.C:
		}
	}
	return result, fmt.Errorf("请求已重试 %d 次仍失败: %w", maxRetries, err)
}
// execute performs one HTTP round trip, with no retries.
//
// Building the URL:
//   - start from the configured URL for path;
//   - add auth query parameters (applyAuthURL);
//   - if paramsInQuery is true and body is a map[string]interface{}, merge
//     those params into the query string.
//
// Otherwise a non-nil body is JSON-encoded and sent with
// Content-Type: application/json. Auth headers come from applyAuthHeader.
//
// For status codes >= 400 it returns both the populated result and an error
// containing the status and body. Marshal, request-creation, transport and
// body-read failures return a nil result.
func (c *ApiClient) execute(ctx context.Context, method, path string, body interface{}, paramsInQuery bool) (*ApiResult, error) {
	start := time.Now()
	// Auth query parameters are added first. buildQueryURL uses Set, so a
	// caller param with the same key overrides the auth value.
	fullURL := c.applyAuthURL(c.config.GetApiUrl(path))

	var reqBody io.Reader
	sendJSON := body != nil && !paramsInQuery
	if sendJSON {
		payload, err := json.Marshal(body)
		if err != nil {
			// Previously ignored, which silently sent an empty body.
			return nil, fmt.Errorf("序列化请求体失败: %w", err)
		}
		reqBody = bytes.NewReader(payload)
	}
	// Query-string params are honoured only when they are a map[string]interface{}.
	if body != nil && paramsInQuery {
		if paramsMap, ok := body.(map[string]interface{}); ok {
			fullURL = c.buildQueryURL(fullURL, paramsMap)
		}
	}

	// The access token may be in the query string (token_in_query), so mask it
	// before the URL reaches the logs. Query values are encoded with
	// url.QueryEscape, so the escaped form is what appears in fullURL.
	logURL := fullURL
	if token := c.config.AccessToken; token != "" {
		logURL = strings.ReplaceAll(logURL, url.QueryEscape(token), "***")
	}
	logrus.Infof("请求 URL: %s", logURL)

	req, err := http.NewRequestWithContext(ctx, method, fullURL, reqBody)
	if err != nil {
		return nil, fmt.Errorf("创建请求失败: %w", err)
	}
	c.applyAuthHeader(req)
	req.Header.Set("User-Agent", "data-engine/1.0")
	if sendJSON {
		req.Header.Set("Content-Type", "application/json")
	}

	resp, err := c.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("请求失败: %w", err)
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		// A truncated body must not be reported as success. Returning an error
		// lets the caller's retry loop try again.
		return nil, fmt.Errorf("读取响应失败: %w", err)
	}
	result := &ApiResult{Body: respBody, DurationMs: time.Since(start).Milliseconds()}
	if resp.StatusCode >= 400 {
		return result, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody))
	}
	return result, nil
}
// buildQueryURL returns rawURL with params merged into its query string.
// A param with the same key as an existing query value replaces it.
//
// Values are rendered as follows:
//   - string: as-is.
//   - bool: "true" or "false".
//   - float64: integral values that fit in int64 are printed as integers, so
//     1000000 does not become "1e+06"; other values use %v.
//   - float32: %v.
//   - integer types: decimal.
//   - []interface{} or map[string]interface{}: JSON-encoded, then URL-escaped
//     by the query encoder.
//   - nil: the key is skipped.
//   - anything else: %v.
//
// If rawURL cannot be parsed it is returned unchanged.
func (c *ApiClient) buildQueryURL(rawURL string, params map[string]interface{}) string {
	parsed, err := url.Parse(rawURL)
	if err != nil {
		// Leave the malformed URL for http.NewRequestWithContext to reject with
		// an error, instead of dereferencing a nil *url.URL here.
		return rawURL
	}
	q := parsed.Query()
	for k, v := range params {
		switch val := v.(type) {
		case nil:
			// A JSON null means "no value"; previously it was sent as "<nil>".
			continue
		case string:
			q.Set(k, val)
		case bool:
			if val {
				q.Set(k, "true")
			} else {
				q.Set(k, "false")
			}
		case float64:
			// JSON numbers decode to float64. Print integral values with %d to
			// avoid scientific notation. The range check keeps the int64
			// conversion defined: out-of-range float-to-int conversions give
			// implementation-dependent results in Go.
			if val >= -(1<<63) && val < 1<<63 && val == float64(int64(val)) {
				q.Set(k, fmt.Sprintf("%d", int64(val)))
			} else {
				q.Set(k, fmt.Sprintf("%v", val))
			}
		case float32:
			q.Set(k, fmt.Sprintf("%v", val))
		case int, int8, int16, int32, int64:
			q.Set(k, fmt.Sprintf("%d", val))
		case uint, uint8, uint16, uint32, uint64:
			q.Set(k, fmt.Sprintf("%d", val))
		case []interface{}, map[string]interface{}:
			// Arrays and objects are JSON-encoded; q.Encode URL-escapes the result.
			b, err := json.Marshal(val)
			if err != nil {
				// For example a NaN inside the value. Fall back to the generic
				// rendering instead of silently sending an empty string.
				q.Set(k, fmt.Sprintf("%v", val))
				continue
			}
			q.Set(k, string(b))
		default:
			q.Set(k, fmt.Sprintf("%v", v))
		}
	}
	parsed.RawQuery = q.Encode()
	return parsed.String()
}
// applyAuthURL adds query-string authentication to rawURL, as configured in
// config.AuthConfig:
//   - "token_in_query" (bool): when true and AccessToken is non-empty, the
//     token is set under "query_key" (default "access_token").
//   - "extra_query_params" (map): extra key/value pairs added to the query.
//     Each value is rendered with %v, then "{timestamp}" (Unix seconds) and
//     "{nonce}" (from generateNonce) are substituted.
//
// rawURL is returned unchanged in three cases: AuthConfig is nil, there is
// nothing to add, or rawURL cannot be parsed.
func (c *ApiClient) applyAuthURL(rawURL string) string {
	cfg := c.config.AuthConfig
	if cfg == nil {
		return rawURL
	}
	token := c.config.AccessToken

	tokenInQuery, _ := cfg["token_in_query"].(bool)
	queryKey, _ := cfg["query_key"].(string)
	if queryKey == "" {
		queryKey = "access_token"
	}

	extraParams := make(map[string]string)
	if eq, ok := cfg["extra_query_params"].(map[string]interface{}); ok {
		// Read the clock once per call so every parameter that references
		// {timestamp} gets the same value, even across a second boundary.
		timestamp := fmt.Sprintf("%d", time.Now().Unix())
		for k, v := range eq {
			val := fmt.Sprintf("%v", v)
			val = strings.ReplaceAll(val, "{timestamp}", timestamp)
			if strings.Contains(val, "{nonce}") {
				val = strings.ReplaceAll(val, "{nonce}", generateNonce())
			}
			extraParams[k] = val
		}
	}
	if !tokenInQuery && len(extraParams) == 0 {
		return rawURL
	}

	parsed, err := url.Parse(rawURL)
	if err != nil {
		// Leave the malformed URL for http.NewRequestWithContext to reject with
		// an error, instead of dereferencing a nil *url.URL here.
		return rawURL
	}
	q := parsed.Query()
	if tokenInQuery && token != "" {
		q.Set(queryKey, token)
	}
	for k, v := range extraParams {
		q.Set(k, v)
	}
	parsed.RawQuery = q.Encode()
	return parsed.String()
}
// applyAuthHeader sets the authentication header on req. It does nothing when
// the token travels in the query string (AuthConfig "token_in_query" is true)
// or when AccessToken is empty.
//
// If AuthConfig has a non-empty "header_name", that header is set to
// "header_format" with "{token}" replaced by the token. A missing, empty or
// non-string format defaults to "{token}".
//
// Otherwise AuthType decides:
//   - OAUTH2 or TOKEN: "Authorization: Bearer <token>".
//   - API_KEY: "X-API-Key: <token>".
//   - any other type: no header is set.
func (c *ApiClient) applyAuthHeader(req *http.Request) {
	cfg := c.config.AuthConfig
	token := c.config.AccessToken
	if cfg != nil {
		if tiq, _ := cfg["token_in_query"].(bool); tiq {
			return
		}
	}
	if token == "" {
		return
	}
	if cfg != nil {
		// An empty header name is not a valid HTTP field name, so fall through to
		// the AuthType defaults instead of producing an invalid header.
		if h, _ := cfg["header_name"].(string); h != "" {
			// Use the comma-ok form. A missing or non-string "header_format" must
			// fall back to the default rather than panic, as the unchecked
			// assertion did.
			f, _ := cfg["header_format"].(string)
			if f == "" {
				f = "{token}"
			}
			req.Header.Set(h, strings.ReplaceAll(f, "{token}", token))
			return
		}
	}
	switch c.config.AuthType {
	case "OAUTH2", "TOKEN":
		req.Header.Set("Authorization", "Bearer "+token)
	case "API_KEY":
		req.Header.Set("X-API-Key", token)
	}
}
// generateNonce returns a numeric nonce made of two parts, "%012d%04d":
//   - the current Unix time in nanoseconds modulo 10^12, zero-padded to 12 digits;
//   - a random number in [0, 10000) from crypto/rand, zero-padded to 4 digits.
//
// applyAuthURL uses it to fill the "{nonce}" placeholder.
func generateNonce() string {
	nanoPart := time.Now().UnixNano() % 1000000000000
	// If crypto/rand fails, r is nil and r.Int64() would panic. Fall back to 0;
	// the time-based prefix still varies between calls.
	var randPart int64
	if r, err := rand.Int(rand.Reader, big.NewInt(10000)); err == nil {
		randPart = r.Int64()
	}
	return fmt.Sprintf("%012d%04d", nanoPart, randPart)
}