Files
model-gateway/service/model_invoker.go

418 lines
10 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"model-gateway/model/entity"
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/frame/g"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// parseHeadMsgHeaders parses a comma-separated list of header bindings into a
// header-name -> header-value map.
//
// Each entry has the form "Name:Value" (preferred) or "Name=Value" (accepted
// for compatibility), for example:
//
//	X-API-Key:qwen3-tts-key,operation:true,count:123
//	X-API-Key:"qwen3-tts-key",operation:"true"
//
// An entry is split at whichever of ':' or '=' appears first, so the value may
// itself contain either character (e.g. "callback=http://host/path" or
// "Authorization:Bearer a=b==").
//
// HTTP header values are always strings; whitespace around the name and value
// is trimmed and leading/trailing double quotes are stripped from the value.
// Entries with no separator, an empty name, or an empty value are skipped.
// When the same name occurs more than once the last entry wins.
//
// Note: entries are split on every comma, so a value cannot contain a comma,
// even when quoted.
//
// It returns nil when headMsg is blank or yields no usable entry.
func parseHeadMsgHeaders(headMsg string) map[string]string {
	headMsg = strings.TrimSpace(headMsg)
	if headMsg == "" {
		return nil
	}

	out := map[string]string{}
	for _, entry := range strings.Split(headMsg, ",") {
		// Split on the first separator of either kind; checking for ':' anywhere
		// in the entry would mis-split "name=value" pairs whose value has a ':'.
		sep := strings.IndexAny(entry, ":=")
		if sep < 0 {
			continue
		}
		name := strings.TrimSpace(entry[:sep])
		value := strings.Trim(strings.TrimSpace(entry[sep+1:]), "\"")
		if name == "" || value == "" {
			continue
		}
		out[name] = value
	}

	if len(out) == 0 {
		return nil
	}
	return out
}
// payloadToQuery flattens payload into URL query parameters.
//
// payload is round-tripped through encoding/json, so it must encode to a JSON
// object (or to null, which yields no parameters); any other shape is reported
// as an error. Each top-level key becomes one query parameter:
//
//   - null values are omitted;
//   - strings are used verbatim;
//   - numbers keep their JSON literal form, so 1000000 stays "1000000" rather
//     than "1e+06" and large integers are not rounded through float64;
//   - booleans become "true" / "false";
//   - arrays and objects are JSON-encoded.
//
// A nil payload returns an empty, non-nil url.Values.
func payloadToQuery(payload any) (url.Values, error) {
	if payload == nil {
		return url.Values{}, nil
	}

	// Normalize any input type to a generic key/value view via JSON.
	raw, err := json.Marshal(payload)
	if err != nil {
		return nil, err
	}

	// UseNumber keeps numeric literals as json.Number instead of float64, whose
	// default formatting switches to exponent notation from 1e6 upward.
	dec := json.NewDecoder(bytes.NewReader(raw))
	dec.UseNumber()
	fields := map[string]any{}
	if err := dec.Decode(&fields); err != nil {
		return nil, err
	}

	q := url.Values{}
	for key, val := range fields {
		switch tv := val.(type) {
		case nil:
			continue
		case string:
			q.Set(key, tv)
		case json.Number:
			q.Set(key, tv.String())
		case bool:
			q.Set(key, fmt.Sprint(tv))
		default:
			// Nested arrays/objects are passed as their JSON text.
			encoded, err := json.Marshal(tv)
			if err != nil {
				return nil, err
			}
			q.Set(key, string(encoded))
		}
	}
	return q, nil
}
// InvokeModel calls the model service described by m and returns the response
// body as bytes.
//
// Request construction:
//   - payload is first passed through mapRequestPayload with m.RequestMapping.
//   - m.BaseURL (trailing slashes removed) is the target URL.
//   - m.HttpMethod is trimmed and upper-cased; empty defaults to POST. GET sends
//     the mapped payload as query parameters. Every other method is sent as a
//     POST with the mapped payload as a JSON body.
//   - m.TimeoutSeconds sets the client timeout; values <= 0 fall back to 60s.
//   - Headers from m.HeadMsg are applied first, then headers parsed from
//     modelKey (same syntax), so modelKey can override or extend the static
//     configuration, e.g. to supply a per-request X-API-Key. For non-GET
//     requests Content-Type is then set to application/json.
//
// Response handling:
//   - A non-2xx status is returned as an error that includes at most the first
//     2000 bytes of the body.
//   - Otherwise the body is passed through mapResponsePayload with
//     m.ResponseMapping. If that mapping fails, a warning is logged and the raw
//     body is returned instead of an error.
func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelKey string) ([]byte, error) {
	if m == nil || m.BaseURL == "" {
		return nil, fmt.Errorf("模型配置不完整")
	}

	// Apply the configured request mapping before building the request.
	mappedPayload, err := mapRequestPayload(m.RequestMapping, payload)
	if err != nil {
		return nil, fmt.Errorf("请求参数映射失败: %w", err)
	}

	endpoint := strings.TrimRight(m.BaseURL, "/")

	timeout := time.Duration(m.TimeoutSeconds) * time.Second
	if timeout <= 0 {
		timeout = 60 * time.Second
	}
	client := &http.Client{Timeout: timeout}

	method := strings.ToUpper(strings.TrimSpace(m.HttpMethod))
	if method == "" {
		method = http.MethodPost
	}

	// The request-construction error is assigned to the function-level err.
	// Declaring a new err with := inside the branches would hide it from the
	// check below and leave req nil.
	var req *http.Request
	if method == http.MethodGet {
		q, qErr := payloadToQuery(mappedPayload)
		if qErr != nil {
			return nil, qErr
		}
		if len(q) > 0 {
			sep := "?"
			if strings.Contains(endpoint, "?") {
				sep = "&"
			}
			endpoint += sep + q.Encode()
		}
		req, err = http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	} else {
		bodyBytes, mErr := json.Marshal(mappedPayload)
		if mErr != nil {
			return nil, mErr
		}
		req, err = http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(bodyBytes))
	}
	if err != nil {
		return nil, err
	}

	// Static headers from the model configuration first ...
	for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) {
		req.Header.Set(hk, hv)
	}
	// ... then the per-request modelKey, which may override them.
	for hk, hv := range parseHeadMsgHeaders(modelKey) {
		req.Header.Set(hk, hv)
	}
	if method != http.MethodGet {
		req.Header.Set("Content-Type", "application/json")
	}

	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	b, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		const maxErrBody = 2000
		msg := string(b)
		if len(msg) > maxErrBody {
			cut := maxErrBody
			// Step back over UTF-8 continuation bytes (10xxxxxx), at most three,
			// so the excerpt does not end in the middle of a multi-byte character.
			for i := 0; i < 3 && cut > 0 && msg[cut]&0xC0 == 0x80; i++ {
				cut--
			}
			msg = msg[:cut]
		}
		return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
	}

	mappedResponse, err := mapResponsePayload(m.ResponseMapping, b)
	if err != nil {
		// A response-mapping failure is not fatal: fall back to the raw body.
		g.Log().Warningf(ctx, "响应参数映射失败: %v返回原始数据", err)
		return b, nil
	}
	return mappedResponse, nil
}
// ============================================
// 映射相关函数
// ============================================
// mapRequestPayload merges the configured request mapping into payload to
// produce the model-specific request body.
//
// mappingAny is decoded by parseRequestMapping; its values may be of any JSON
// type (bool, number, string, ...). When the mapping is empty the payload is
// returned unchanged.
//
// Otherwise the payload is viewed as a map:
//   - a map[string]any is shallow-copied, so the caller's map is not modified;
//   - a []map[string]any is wrapped as {"messages": payload};
//   - anything else is converted through a JSON round trip and must encode to
//     a JSON object or null.
//
// Mapping entries act as defaults: a key is written only when the payload does
// not have it or holds a value that isEmptyValue reports as empty. The merged
// map is returned.
func mapRequestPayload(mappingAny any, payload any) (any, error) {
	mapping, err := parseRequestMapping(mappingAny)
	if err != nil {
		return nil, err
	}
	if len(mapping) == 0 {
		return payload, nil
	}

	var merged map[string]any
	switch v := payload.(type) {
	case map[string]any:
		// Copy instead of aliasing so defaults are never written into a map the
		// caller may still be using elsewhere.
		merged = make(map[string]any, len(v)+len(mapping))
		for k, val := range v {
			merged[k] = val
		}
	case []map[string]any:
		// A bare messages array is wrapped into the standard shape.
		merged = map[string]any{
			"messages": v,
		}
	default:
		jsonBytes, err := json.Marshal(payload)
		if err != nil {
			return nil, fmt.Errorf("序列化payload失败: %w", err)
		}
		if err := json.Unmarshal(jsonBytes, &merged); err != nil {
			return nil, fmt.Errorf("反序列化payload失败: %w", err)
		}
	}
	// A nil payload encodes to JSON null, which leaves the map nil; writing to
	// a nil map would panic.
	if merged == nil {
		merged = make(map[string]any, len(mapping))
	}

	for key, value := range mapping {
		if existing, ok := merged[key]; !ok || isEmptyValue(existing) {
			merged[key] = value
		}
	}
	return merged, nil
}
// mapResponsePayload rewrites a model response into the standard format
// described by the response mapping.
//
// mappingAny is decoded by parseResponseMapping into standardField -> path
// pairs. When the mapping is empty, responseBytes is returned untouched.
// Otherwise a new JSON document is built starting from "{}": for every pair
// whose path exists in the response (looked up with gjson), the raw JSON value
// is stored under standardField (with sjson). Paths that do not exist are
// skipped.
//
// The mapping is a Go map, so the order in which fields are written is not
// fixed between calls.
func mapResponsePayload(mappingAny any, responseBytes []byte) ([]byte, error) {
	fieldPaths, err := parseResponseMapping(mappingAny)
	switch {
	case err != nil:
		return nil, err
	case len(fieldPaths) == 0:
		return responseBytes, nil
	}

	source := string(responseBytes)
	out := `{}`
	for target, path := range fieldPaths {
		found := gjson.Get(source, path)
		if found.Exists() {
			out, err = sjson.SetRaw(out, target, found.Raw)
			if err != nil {
				return nil, fmt.Errorf("提取字段 %s <- %s 失败: %w", target, path, err)
			}
		}
	}
	return []byte(out), nil
}
// parseRequestMapping decodes the request-mapping configuration into a
// key -> value map. Values keep whatever JSON type they were configured with.
//
// Accepted inputs:
//   - nil: no mapping (nil, nil);
//   - *gvar.Var: nil/empty yields no mapping; otherwise its Map() view is
//     copied, or, if that is nil, its String() form is parsed as JSON unless it
//     is "", "{}" or "null";
//   - map[string]interface{}: returned as is;
//   - string: parsed as JSON unless it is "", "{}" or "null";
//   - []byte: parsed as JSON unless empty;
//   - anything else: converted through a JSON marshal/unmarshal round trip.
func parseRequestMapping(mappingAny any) (map[string]any, error) {
	if mappingAny == nil {
		return nil, nil
	}

	parsed := map[string]any{}

	switch src := mappingAny.(type) {
	case *gvar.Var:
		if src == nil || src.IsNil() || src.IsEmpty() {
			return nil, nil
		}
		// Prefer the map view when the variable can provide one.
		if asMap := src.Map(); asMap != nil {
			for name, val := range asMap {
				parsed[name] = val
			}
			return parsed, nil
		}
		// Otherwise fall back to treating it as JSON text.
		text := src.String()
		if text == "" || text == "{}" || text == "null" {
			return nil, nil
		}
		if err := json.Unmarshal([]byte(text), &parsed); err != nil {
			return nil, fmt.Errorf("解析请求映射字符串失败: %w", err)
		}
		return parsed, nil

	case map[string]interface{}:
		return src, nil

	case string:
		switch src {
		case "", "{}", "null":
			return nil, nil
		}
		if err := json.Unmarshal([]byte(src), &parsed); err != nil {
			return nil, fmt.Errorf("解析请求映射字符串失败: %w", err)
		}
		return parsed, nil

	case []byte:
		if len(src) == 0 {
			return nil, nil
		}
		if err := json.Unmarshal(src, &parsed); err != nil {
			return nil, fmt.Errorf("解析请求映射字节失败: %w", err)
		}
		return parsed, nil
	}

	// Any other type: go through JSON to obtain a generic map.
	encoded, err := json.Marshal(mappingAny)
	if err != nil {
		return nil, fmt.Errorf("序列化映射配置失败: %w", err)
	}
	if err := json.Unmarshal(encoded, &parsed); err != nil {
		return nil, fmt.Errorf("解析映射配置失败: %w", err)
	}
	return parsed, nil
}
// parseResponseMapping decodes the response-mapping configuration into a
// map[string]string whose values are JSON path strings.
//
// Accepted inputs:
//   - nil: no mapping (nil, nil);
//   - *gvar.Var: nil/empty yields no mapping; otherwise the string-valued
//     entries of its Map() view are kept, or, if that view is nil, its String()
//     form is parsed as JSON unless it is "", "{}" or "null";
//   - string: parsed as JSON unless it is "", "{}" or "null";
//   - map[string]interface{} (e.g. a JSONB column already decoded to a map):
//     string-valued entries are kept, others are dropped;
//   - []byte: parsed as JSON unless empty;
//   - anything else: converted through a JSON marshal/unmarshal round trip.
func parseResponseMapping(mappingAny any) (map[string]string, error) {
	if mappingAny == nil {
		return nil, nil
	}

	paths := map[string]string{}
	// keepStrings copies only the entries whose value is a string into paths.
	keepStrings := func(src map[string]interface{}) {
		for field, raw := range src {
			if path, isString := raw.(string); isString {
				paths[field] = path
			}
		}
	}

	switch src := mappingAny.(type) {
	case *gvar.Var:
		if src == nil || src.IsNil() || src.IsEmpty() {
			return nil, nil
		}
		if asMap := src.Map(); asMap != nil {
			keepStrings(asMap)
			return paths, nil
		}
		text := src.String()
		if text == "" || text == "{}" || text == "null" {
			return nil, nil
		}
		if err := json.Unmarshal([]byte(text), &paths); err != nil {
			return nil, fmt.Errorf("解析响应映射字符串失败: %w", err)
		}
		return paths, nil

	case string:
		switch src {
		case "", "{}", "null":
			return nil, nil
		}
		if err := json.Unmarshal([]byte(src), &paths); err != nil {
			return nil, fmt.Errorf("解析响应映射字符串失败: %w", err)
		}
		return paths, nil

	case map[string]interface{}:
		keepStrings(src)
		return paths, nil

	case []byte:
		if len(src) == 0 {
			return nil, nil
		}
		if err := json.Unmarshal(src, &paths); err != nil {
			return nil, fmt.Errorf("解析响应映射字节失败: %w", err)
		}
		return paths, nil
	}

	// Any other type: go through JSON to obtain the string map.
	encoded, err := json.Marshal(mappingAny)
	if err != nil {
		return nil, fmt.Errorf("序列化响应映射配置失败: %w", err)
	}
	if err := json.Unmarshal(encoded, &paths); err != nil {
		return nil, fmt.Errorf("解析响应映射配置失败: %w", err)
	}
	return paths, nil
}
// isEmptyValue reports whether v counts as "empty": a nil interface, an empty
// string, a []any with no elements, or a map[string]any with no entries.
// Values of every other type (numbers, booleans, other slice/map types, ...)
// are never considered empty.
func isEmptyValue(v any) bool {
	if v == nil {
		return true
	}
	if text, isString := v.(string); isString {
		return len(text) == 0
	}
	if list, isList := v.([]any); isList {
		return len(list) == 0
	}
	if object, isObject := v.(map[string]any); isObject {
		return len(object) == 0
	}
	return false
}