first commit
This commit is contained in:
151
service/model_invoker.go
Normal file
151
service/model_invoker.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"model-asynch/model/entity"
|
||||
)
|
||||
|
||||
func parseAPIKeyHeader(apiKey string) (k, v string) {
|
||||
apiKey = strings.TrimSpace(apiKey)
|
||||
if apiKey == "" {
|
||||
return "", ""
|
||||
}
|
||||
// 支持两种写法:
|
||||
// 1) HeaderName:HeaderValue(推荐)
|
||||
// 2) HeaderName=HeaderValue(兼容)
|
||||
if strings.Contains(apiKey, ":") {
|
||||
parts := strings.SplitN(apiKey, ":", 2)
|
||||
return strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])
|
||||
}
|
||||
if strings.Contains(apiKey, "=") {
|
||||
parts := strings.SplitN(apiKey, "=", 2)
|
||||
return strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])
|
||||
}
|
||||
// 只给了 value:不做注入(避免注入非法 header)
|
||||
return "", ""
|
||||
}
|
||||
|
||||
func payloadToQuery(payload any) (url.Values, error) {
|
||||
if payload == nil {
|
||||
return url.Values{}, nil
|
||||
}
|
||||
// 统一转成 map[string]any
|
||||
b, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m := map[string]any{}
|
||||
if err := json.Unmarshal(b, &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
q := url.Values{}
|
||||
for k, v := range m {
|
||||
if v == nil {
|
||||
continue
|
||||
}
|
||||
// 复杂类型直接 json 字符串化
|
||||
switch vv := v.(type) {
|
||||
case string:
|
||||
q.Set(k, vv)
|
||||
case float64, bool, int, int64, uint64:
|
||||
q.Set(k, fmt.Sprintf("%v", vv))
|
||||
default:
|
||||
bs, _ := json.Marshal(v)
|
||||
q.Set(k, string(bs))
|
||||
}
|
||||
}
|
||||
return q, nil
|
||||
}
|
||||
|
||||
// InvokeModel 调用模型服务,返回二进制结果
|
||||
func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any) ([]byte, error) {
|
||||
if m == nil || m.BaseURL == "" {
|
||||
return nil, fmt.Errorf("模型配置不完整")
|
||||
}
|
||||
url := strings.TrimRight(m.BaseURL, "/") + "/" + strings.TrimLeft(m.Route, "/")
|
||||
if strings.TrimSpace(m.Route) == "" {
|
||||
url = strings.TrimRight(m.BaseURL, "/")
|
||||
}
|
||||
|
||||
timeout := time.Duration(m.TimeoutMs) * time.Millisecond
|
||||
if timeout <= 0 {
|
||||
timeout = 60 * time.Second
|
||||
}
|
||||
client := &http.Client{Timeout: timeout}
|
||||
|
||||
method := strings.ToUpper(strings.TrimSpace(m.HttpMethod))
|
||||
if method == "" {
|
||||
method = http.MethodPost
|
||||
}
|
||||
|
||||
var (
|
||||
req *http.Request
|
||||
err error
|
||||
)
|
||||
switch method {
|
||||
case http.MethodGet:
|
||||
q, err := payloadToQuery(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(q) > 0 {
|
||||
if strings.Contains(url, "?") {
|
||||
url = url + "&" + q.Encode()
|
||||
} else {
|
||||
url = url + "?" + q.Encode()
|
||||
}
|
||||
}
|
||||
req, err = http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
default:
|
||||
bodyBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err = http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyBytes))
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 透传必要头部(如 Authorization / X-User-Info),以及注入模型配置里的 api_key
|
||||
for k, v := range forwardHeaders(ctx) {
|
||||
if v != "" {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
}
|
||||
if hk, hv := parseAPIKeyHeader(m.APIKey); hk != "" && hv != "" {
|
||||
req.Header.Set(hk, hv)
|
||||
}
|
||||
if method != http.MethodGet {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
// 尽量把错误体带回去,方便排查
|
||||
msg := string(b)
|
||||
if len(msg) > 2000 {
|
||||
msg = msg[:2000]
|
||||
}
|
||||
return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
Reference in New Issue
Block a user