186 lines
4.4 KiB
Go
186 lines
4.4 KiB
Go
|
|
package service
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"bytes"
|
|||
|
|
"context"
|
|||
|
|
"encoding/json"
|
|||
|
|
"fmt"
|
|||
|
|
"io"
|
|||
|
|
"net/http"
|
|||
|
|
"net/url"
|
|||
|
|
"strings"
|
|||
|
|
"time"
|
|||
|
|
|
|||
|
|
"model-asynch/model/entity"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// parseHeadMsgHeaders 支持多个 header 绑定,逗号分隔:
|
|||
|
|
// 示例:
|
|||
|
|
// - X-API-Key:qwen3-tts-key,operation:true,count:123
|
|||
|
|
// - X-API-Key:"qwen3-tts-key",operation:"true"
|
|||
|
|
//
|
|||
|
|
// 说明:
|
|||
|
|
// - HTTP Header 最终都是字符串,这里做的是“值的字符串化表达”。
|
|||
|
|
// - 若 value 用双引号包裹,会去掉外层引号再注入,便于在配置中区分字符串/布尔/数字等表达(以及避免值中包含特殊字符时歧义)。
|
|||
|
|
func parseHeadMsgHeaders(headMsg string) map[string]string {
|
|||
|
|
headMsg = strings.TrimSpace(headMsg)
|
|||
|
|
if headMsg == "" {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
out := map[string]string{}
|
|||
|
|
parts := strings.Split(headMsg, ",")
|
|||
|
|
for _, p := range parts {
|
|||
|
|
p = strings.TrimSpace(p)
|
|||
|
|
if p == "" {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
// HeaderName:HeaderValue(推荐) / HeaderName=HeaderValue(兼容)
|
|||
|
|
if strings.Contains(p, ":") {
|
|||
|
|
kv := strings.SplitN(p, ":", 2)
|
|||
|
|
k := strings.TrimSpace(kv[0])
|
|||
|
|
v := strings.TrimSpace(kv[1])
|
|||
|
|
v = strings.Trim(v, "\"")
|
|||
|
|
if k != "" && v != "" {
|
|||
|
|
out[k] = v
|
|||
|
|
}
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
if strings.Contains(p, "=") {
|
|||
|
|
kv := strings.SplitN(p, "=", 2)
|
|||
|
|
k := strings.TrimSpace(kv[0])
|
|||
|
|
v := strings.TrimSpace(kv[1])
|
|||
|
|
v = strings.Trim(v, "\"")
|
|||
|
|
if k != "" && v != "" {
|
|||
|
|
out[k] = v
|
|||
|
|
}
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
if len(out) == 0 {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
return out
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func payloadToQuery(payload any) (url.Values, error) {
|
|||
|
|
if payload == nil {
|
|||
|
|
return url.Values{}, nil
|
|||
|
|
}
|
|||
|
|
// 统一转成 map[string]any
|
|||
|
|
b, err := json.Marshal(payload)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
m := map[string]any{}
|
|||
|
|
if err := json.Unmarshal(b, &m); err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
q := url.Values{}
|
|||
|
|
for k, v := range m {
|
|||
|
|
if v == nil {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
// 复杂类型直接 json 字符串化
|
|||
|
|
switch vv := v.(type) {
|
|||
|
|
case string:
|
|||
|
|
q.Set(k, vv)
|
|||
|
|
case float64, bool, int, int64, uint64:
|
|||
|
|
q.Set(k, fmt.Sprintf("%v", vv))
|
|||
|
|
default:
|
|||
|
|
bs, _ := json.Marshal(v)
|
|||
|
|
q.Set(k, string(bs))
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
return q, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// InvokeModel 调用模型服务,返回二进制结果
|
|||
|
|
// modelKey 用于覆盖/补充模型配置 head_msg(例如每次请求携带不同的 X-API-Key)。
|
|||
|
|
func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelKey string) ([]byte, error) {
|
|||
|
|
if m == nil || m.BaseURL == "" {
|
|||
|
|
return nil, fmt.Errorf("模型配置不完整")
|
|||
|
|
}
|
|||
|
|
url := strings.TrimRight(m.BaseURL, "/") + "/" + strings.TrimLeft(m.Route, "/")
|
|||
|
|
if strings.TrimSpace(m.Route) == "" {
|
|||
|
|
url = strings.TrimRight(m.BaseURL, "/")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
timeout := time.Duration(m.TimeoutSeconds) * time.Second
|
|||
|
|
if timeout <= 0 {
|
|||
|
|
timeout = 60 * time.Second
|
|||
|
|
}
|
|||
|
|
client := &http.Client{Timeout: timeout}
|
|||
|
|
|
|||
|
|
method := strings.ToUpper(strings.TrimSpace(m.HttpMethod))
|
|||
|
|
if method == "" {
|
|||
|
|
method = http.MethodPost
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
var (
|
|||
|
|
req *http.Request
|
|||
|
|
err error
|
|||
|
|
)
|
|||
|
|
switch method {
|
|||
|
|
case http.MethodGet:
|
|||
|
|
q, err := payloadToQuery(payload)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
if len(q) > 0 {
|
|||
|
|
if strings.Contains(url, "?") {
|
|||
|
|
url = url + "&" + q.Encode()
|
|||
|
|
} else {
|
|||
|
|
url = url + "?" + q.Encode()
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
req, err = http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
|||
|
|
default:
|
|||
|
|
bodyBytes, err := json.Marshal(payload)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
req, err = http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyBytes))
|
|||
|
|
}
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 先注入模型配置 head_msg(静态头部)
|
|||
|
|
for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) {
|
|||
|
|
req.Header.Set(hk, hv)
|
|||
|
|
}
|
|||
|
|
// 透传必要头部(如 Authorization / X-User-Info)
|
|||
|
|
for k, v := range forwardHeaders(ctx) {
|
|||
|
|
if v != "" {
|
|||
|
|
req.Header.Set(k, v)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
// 最后注入动态 modelKey(覆盖/补充静态 head_msg)
|
|||
|
|
for hk, hv := range parseHeadMsgHeaders(modelKey) {
|
|||
|
|
req.Header.Set(hk, hv)
|
|||
|
|
}
|
|||
|
|
if method != http.MethodGet {
|
|||
|
|
req.Header.Set("Content-Type", "application/json")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
resp, err := client.Do(req)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
defer resp.Body.Close()
|
|||
|
|
|
|||
|
|
b, err := io.ReadAll(resp.Body)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|||
|
|
// 尽量把错误体带回去,方便排查
|
|||
|
|
msg := string(b)
|
|||
|
|
if len(msg) > 2000 {
|
|||
|
|
msg = msg[:2000]
|
|||
|
|
}
|
|||
|
|
return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
|
|||
|
|
}
|
|||
|
|
return b, nil
|
|||
|
|
}
|