Files
model-gateway/service/model_invoker.go

418 lines
10 KiB
Go
Raw Normal View History

2026-04-29 15:54:14 +08:00
package service
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"model-asynch/model/entity"
2026-05-12 13:45:08 +08:00
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/frame/g"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
2026-04-29 15:54:14 +08:00
)
// parseHeadMsgHeaders parses a comma-separated list of header bindings into a
// header-name to header-value map.
//
// Examples:
//   - X-API-Key:qwen3-tts-key,operation:true,count:123
//   - X-API-Key:"qwen3-tts-key",operation:"true"
//   - X-Callback=https://example.com:8443/cb
//
// Notes:
//   - Each binding is "Name:Value" (recommended) or "Name=Value"
//     (compatibility form). The separator that appears first in the binding is
//     the one used, so a value may itself contain ':' or '='.
//   - HTTP header values are always strings; surrounding double quotes are
//     stripped from the value before it is stored.
//   - Commas inside a double-quoted value do not split the binding. If the
//     quotes in headMsg are unbalanced, every comma is treated as a separator.
//   - Bindings with an empty name or an empty value are skipped.
//
// It returns nil when headMsg is blank or contains no usable binding.
func parseHeadMsgHeaders(headMsg string) map[string]string {
	headMsg = strings.TrimSpace(headMsg)
	if headMsg == "" {
		return nil
	}
	out := map[string]string{}
	for _, p := range splitHeadMsgPairs(headMsg) {
		p = strings.TrimSpace(p)
		if p == "" {
			continue
		}
		// Split on whichever of ':' / '=' comes first so that values such as
		// URLs ("Name=https://host:port") are not cut at the wrong character.
		i := strings.IndexAny(p, ":=")
		if i < 0 {
			continue
		}
		k := strings.TrimSpace(p[:i])
		v := strings.Trim(strings.TrimSpace(p[i+1:]), "\"")
		if k != "" && v != "" {
			out[k] = v
		}
	}
	if len(out) == 0 {
		return nil
	}
	return out
}

// splitHeadMsgPairs splits s on commas that are outside double quotes. When
// the quotes in s are unbalanced it falls back to splitting on every comma.
func splitHeadMsgPairs(s string) []string {
	var parts []string
	inQuote := false
	start := 0
	// Byte-wise scanning is safe: '"' and ',' never occur inside a multi-byte
	// UTF-8 sequence.
	for i := 0; i < len(s); i++ {
		switch s[i] {
		case '"':
			inQuote = !inQuote
		case ',':
			if !inQuote {
				parts = append(parts, s[start:i])
				start = i + 1
			}
		}
	}
	if inQuote {
		return strings.Split(s, ",")
	}
	return append(parts, s[start:])
}
// payloadToQuery flattens payload into URL query parameters.
//
// payload is round-tripped through JSON, so it must encode as a JSON object
// (or null); any other shape yields the decoder's error. Top-level fields map
// to query values as follows:
//   - null fields are omitted;
//   - strings are used verbatim;
//   - numbers keep the exact literal produced by json.Marshal, so integers are
//     never rendered in exponent form and large int64 values keep precision;
//   - booleans become "true" / "false";
//   - arrays and objects are embedded as their JSON text.
//
// A nil payload yields an empty, non-nil url.Values.
func payloadToQuery(payload any) (url.Values, error) {
	if payload == nil {
		return url.Values{}, nil
	}
	raw, err := json.Marshal(payload)
	if err != nil {
		return nil, err
	}
	// Decode with UseNumber: a plain Unmarshal turns every number into float64,
	// and formatting that with %v prints e.g. 12345678 as 1.2345678e+07.
	dec := json.NewDecoder(bytes.NewReader(raw))
	dec.UseNumber()
	var fields map[string]any
	if err := dec.Decode(&fields); err != nil {
		return nil, err
	}
	q := make(url.Values, len(fields))
	for k, v := range fields {
		switch vv := v.(type) {
		case nil:
			continue
		case string:
			q.Set(k, vv)
		case json.Number:
			q.Set(k, vv.String())
		case bool:
			q.Set(k, fmt.Sprint(vv))
		default:
			// Composite values (arrays, objects) are passed as JSON text.
			bs, err := json.Marshal(vv)
			if err != nil {
				return nil, err
			}
			q.Set(k, string(bs))
		}
	}
	return q, nil
}
// InvokeModel 调用模型服务,返回二进制结果
// modelKey 用于覆盖/补充模型配置 head_msg例如每次请求携带不同的 X-API-Key
func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelKey string) ([]byte, error) {
if m == nil || m.BaseURL == "" {
return nil, fmt.Errorf("模型配置不完整")
}
2026-05-12 13:45:08 +08:00
// ============ 新增:请求参数映射 ============
mappedPayload, err := mapRequestPayload(m.RequestMapping, payload)
if err != nil {
return nil, fmt.Errorf("请求参数映射失败: %w", err)
2026-04-29 15:54:14 +08:00
}
2026-05-12 13:45:08 +08:00
url := strings.TrimRight(m.BaseURL, "/")
2026-04-29 15:54:14 +08:00
timeout := time.Duration(m.TimeoutSeconds) * time.Second
if timeout <= 0 {
timeout = 60 * time.Second
}
client := &http.Client{Timeout: timeout}
method := strings.ToUpper(strings.TrimSpace(m.HttpMethod))
if method == "" {
method = http.MethodPost
}
var (
req *http.Request
)
switch method {
case http.MethodGet:
2026-05-12 13:45:08 +08:00
q, err := payloadToQuery(mappedPayload) // 使用映射后的payload
2026-04-29 15:54:14 +08:00
if err != nil {
return nil, err
}
if len(q) > 0 {
if strings.Contains(url, "?") {
url = url + "&" + q.Encode()
} else {
url = url + "?" + q.Encode()
}
}
req, err = http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
default:
2026-05-12 13:45:08 +08:00
bodyBytes, err := json.Marshal(mappedPayload) // 使用映射后的payload
2026-04-29 15:54:14 +08:00
if err != nil {
return nil, err
}
req, err = http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyBytes))
}
if err != nil {
return nil, err
}
2026-05-12 13:45:08 +08:00
// 先注入模型配置 head_msg静态头部适合公共模型固定 API Key
2026-04-29 15:54:14 +08:00
for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) {
req.Header.Set(hk, hv)
}
2026-05-12 13:45:08 +08:00
// 最后注入动态 modelKey允许覆盖/补充静态 head_msg适合按请求动态传密钥。
2026-04-29 15:54:14 +08:00
for hk, hv := range parseHeadMsgHeaders(modelKey) {
req.Header.Set(hk, hv)
}
2026-05-12 13:45:08 +08:00
2026-04-29 15:54:14 +08:00
if method != http.MethodGet {
req.Header.Set("Content-Type", "application/json")
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
msg := string(b)
if len(msg) > 2000 {
msg = msg[:2000]
}
return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
}
2026-05-12 13:45:08 +08:00
// ============ 新增:响应参数映射 ============
mappedResponse, err := mapResponsePayload(m.ResponseMapping, b)
if err != nil {
// 响应映射失败不阻塞,返回原始数据
g.Log().Warningf(ctx, "响应参数映射失败: %v返回原始数据", err)
return b, nil
}
// =========================================
return mappedResponse, nil
}
// ============================================
// 映射相关函数
// ============================================
// mapRequestPayload merges the fixed parameters configured in mappingAny into
// payload and returns the merged request body.
//
// mappingAny is decoded by parseRequestMapping. When it yields no entries,
// payload is returned unchanged. Otherwise payload is viewed as a JSON object:
//   - a map[string]any is shallow-copied, so the caller's map is not modified;
//   - a []map[string]any is wrapped as {"messages": payload};
//   - any other value is converted through a JSON round trip and must encode
//     as a JSON object (or null).
//
// A configured parameter is applied only when the payload lacks that key or
// holds an empty value for it (see isEmptyValue); non-empty caller values are
// kept as they are.
func mapRequestPayload(mappingAny any, payload any) (any, error) {
	// 1. Decode the request mapping; values are of type any (bool, number, ...).
	mapping, err := parseRequestMapping(mappingAny)
	if err != nil {
		return nil, err
	}
	// Without a mapping the original payload is passed through untouched.
	if len(mapping) == 0 {
		return payload, nil
	}

	// 2. Turn the payload into a map we own.
	var payloadMap map[string]any
	switch v := payload.(type) {
	case map[string]any:
		// Copy instead of aliasing: writing defaults into the caller's map
		// would leak one model's fixed parameters into later uses of it.
		payloadMap = make(map[string]any, len(v)+len(mapping))
		for k, val := range v {
			payloadMap[k] = val
		}
	case []map[string]any:
		// A bare messages array is wrapped into the standard shape.
		payloadMap = map[string]any{
			"messages": v,
		}
	default:
		// Convert through JSON.
		jsonBytes, err := json.Marshal(payload)
		if err != nil {
			return nil, fmt.Errorf("序列化payload失败: %w", err)
		}
		if err := json.Unmarshal(jsonBytes, &payloadMap); err != nil {
			return nil, fmt.Errorf("反序列化payload失败: %w", err)
		}
	}
	// A nil payload encodes as JSON null, which leaves payloadMap nil; writing
	// to a nil map panics, so make sure there is a map to fill.
	if payloadMap == nil {
		payloadMap = make(map[string]any, len(mapping))
	}

	// 3. Fill in the configured fixed parameters where the payload has none.
	for key, value := range mapping {
		if existingValue, exists := payloadMap[key]; !exists || isEmptyValue(existingValue) {
			payloadMap[key] = value
		}
	}
	return payloadMap, nil
}
// mapResponsePayload rebuilds a model response as a JSON object laid out
// according to the response mapping.
//
// mappingAny is decoded by parseResponseMapping into pairs of output field and
// source path. When there are no pairs, responseBytes is returned unchanged.
// Otherwise the result starts as "{}"; for every pair the source path is looked
// up in the response with gjson.Get and, if it exists, its raw JSON is written
// to the output field with sjson.SetRaw. Paths that do not exist are skipped.
func mapResponsePayload(mappingAny any, responseBytes []byte) ([]byte, error) {
	fieldPaths, err := parseResponseMapping(mappingAny)
	if err != nil {
		return nil, err
	}
	if len(fieldPaths) == 0 {
		return responseBytes, nil
	}

	source := string(responseBytes)
	out := `{}`
	for target, path := range fieldPaths {
		found := gjson.Get(source, path)
		if found.Exists() {
			next, setErr := sjson.SetRaw(out, target, found.Raw)
			if setErr != nil {
				return nil, fmt.Errorf("提取字段 %s <- %s 失败: %w", target, path, setErr)
			}
			out = next
		}
	}
	return []byte(out), nil
}
// parseRequestMapping decodes a request-mapping configuration into a map of
// fixed parameter names to values.
//
// Accepted inputs:
//   - nil: no mapping (nil, nil);
//   - *gvar.Var: nil/empty yields no mapping; otherwise its Map() result is
//     copied, or, when Map() returns nil, its String() form is parsed as JSON;
//   - map[string]any: returned as is (not copied);
//   - string: "", "{}" and "null" yield no mapping, anything else is parsed as
//     JSON;
//   - []byte: empty yields no mapping, anything else is parsed as JSON;
//   - any other type: converted through a JSON round trip.
func parseRequestMapping(mappingAny any) (map[string]any, error) {
	if mappingAny == nil {
		return nil, nil
	}

	switch cfg := mappingAny.(type) {
	case *gvar.Var:
		if cfg == nil || cfg.IsNil() || cfg.IsEmpty() {
			return nil, nil
		}
		// Prefer the map view of the variable.
		if asMap := cfg.Map(); asMap != nil {
			copied := make(map[string]any)
			for name, val := range asMap {
				copied[name] = val
			}
			return copied, nil
		}
		// Otherwise fall back to its string form.
		text := cfg.String()
		if text == "" || text == "{}" || text == "null" {
			return nil, nil
		}
		parsed := make(map[string]any)
		if err := json.Unmarshal([]byte(text), &parsed); err != nil {
			return nil, fmt.Errorf("解析请求映射字符串失败: %w", err)
		}
		return parsed, nil

	case map[string]any:
		return cfg, nil

	case string:
		if cfg == "" || cfg == "{}" || cfg == "null" {
			return nil, nil
		}
		parsed := make(map[string]any)
		if err := json.Unmarshal([]byte(cfg), &parsed); err != nil {
			return nil, fmt.Errorf("解析请求映射字符串失败: %w", err)
		}
		return parsed, nil

	case []byte:
		if len(cfg) == 0 {
			return nil, nil
		}
		parsed := make(map[string]any)
		if err := json.Unmarshal(cfg, &parsed); err != nil {
			return nil, fmt.Errorf("解析请求映射字节失败: %w", err)
		}
		return parsed, nil

	default:
		encoded, err := json.Marshal(mappingAny)
		if err != nil {
			return nil, fmt.Errorf("序列化映射配置失败: %w", err)
		}
		parsed := make(map[string]any)
		if err := json.Unmarshal(encoded, &parsed); err != nil {
			return nil, fmt.Errorf("解析映射配置失败: %w", err)
		}
		return parsed, nil
	}
}
// parseResponseMapping decodes a response-mapping configuration into a
// map[string]string whose values are JSON path strings.
//
// Accepted inputs:
//   - nil: no mapping (nil, nil);
//   - *gvar.Var: nil/empty yields no mapping; otherwise the string-valued
//     entries of its Map() result are kept, or, when Map() returns nil, its
//     String() form is parsed as JSON;
//   - string: "", "{}" and "null" yield no mapping, anything else is parsed as
//     JSON;
//   - map[string]any: only string-valued entries are kept;
//   - []byte: empty yields no mapping, anything else is parsed as JSON;
//   - any other type: converted through a JSON round trip.
func parseResponseMapping(mappingAny any) (map[string]string, error) {
	if mappingAny == nil {
		return nil, nil
	}

	// stringEntries keeps the entries whose value is a string and drops the rest.
	stringEntries := func(src map[string]any) map[string]string {
		kept := make(map[string]string)
		for field, raw := range src {
			if path, isString := raw.(string); isString {
				kept[field] = path
			}
		}
		return kept
	}

	switch cfg := mappingAny.(type) {
	case *gvar.Var:
		if cfg == nil || cfg.IsNil() || cfg.IsEmpty() {
			return nil, nil
		}
		if asMap := cfg.Map(); asMap != nil {
			return stringEntries(asMap), nil
		}
		text := cfg.String()
		if text == "" || text == "{}" || text == "null" {
			return nil, nil
		}
		parsed := make(map[string]string)
		if err := json.Unmarshal([]byte(text), &parsed); err != nil {
			return nil, fmt.Errorf("解析响应映射字符串失败: %w", err)
		}
		return parsed, nil

	case string:
		if cfg == "" || cfg == "{}" || cfg == "null" {
			return nil, nil
		}
		parsed := make(map[string]string)
		if err := json.Unmarshal([]byte(cfg), &parsed); err != nil {
			return nil, fmt.Errorf("解析响应映射字符串失败: %w", err)
		}
		return parsed, nil

	case map[string]any:
		// A map handed over directly, e.g. from a database JSONB column.
		return stringEntries(cfg), nil

	case []byte:
		if len(cfg) == 0 {
			return nil, nil
		}
		parsed := make(map[string]string)
		if err := json.Unmarshal(cfg, &parsed); err != nil {
			return nil, fmt.Errorf("解析响应映射字节失败: %w", err)
		}
		return parsed, nil

	default:
		encoded, err := json.Marshal(mappingAny)
		if err != nil {
			return nil, fmt.Errorf("序列化响应映射配置失败: %w", err)
		}
		parsed := make(map[string]string)
		if err := json.Unmarshal(encoded, &parsed); err != nil {
			return nil, fmt.Errorf("解析响应映射配置失败: %w", err)
		}
		return parsed, nil
	}
}
// isEmptyValue reports whether v counts as "not provided": a nil interface, an
// empty string, a zero-length []any, or a zero-length map[string]any. Every
// other value, including numeric zero, false and other slice or map types, is
// treated as non-empty.
func isEmptyValue(v any) bool {
	switch val := v.(type) {
	case nil:
		return true
	case string:
		return len(val) == 0
	case []any:
		return len(val) == 0
	case map[string]any:
		return len(val) == 0
	}
	return false
}