gatway
This commit is contained in:
@@ -12,6 +12,11 @@ import (
|
||||
"time"
|
||||
|
||||
"model-asynch/model/entity"
|
||||
|
||||
"github.com/gogf/gf/v2/container/gvar"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// parseHeadMsgHeaders 支持多个 header 绑定,逗号分隔:
|
||||
@@ -100,11 +105,14 @@ func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelK
|
||||
if m == nil || m.BaseURL == "" {
|
||||
return nil, fmt.Errorf("模型配置不完整")
|
||||
}
|
||||
url := strings.TrimRight(m.BaseURL, "/") + "/" + strings.TrimLeft(m.Route, "/")
|
||||
if strings.TrimSpace(m.Route) == "" {
|
||||
url = strings.TrimRight(m.BaseURL, "/")
|
||||
|
||||
// ============ 新增:请求参数映射 ============
|
||||
mappedPayload, err := mapRequestPayload(m.RequestMapping, payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求参数映射失败: %w", err)
|
||||
}
|
||||
|
||||
url := strings.TrimRight(m.BaseURL, "/")
|
||||
timeout := time.Duration(m.TimeoutSeconds) * time.Second
|
||||
if timeout <= 0 {
|
||||
timeout = 60 * time.Second
|
||||
@@ -118,11 +126,10 @@ func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelK
|
||||
|
||||
var (
|
||||
req *http.Request
|
||||
err error
|
||||
)
|
||||
switch method {
|
||||
case http.MethodGet:
|
||||
q, err := payloadToQuery(payload)
|
||||
q, err := payloadToQuery(mappedPayload) // 使用映射后的payload
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -135,7 +142,7 @@ func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelK
|
||||
}
|
||||
req, err = http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
default:
|
||||
bodyBytes, err := json.Marshal(payload)
|
||||
bodyBytes, err := json.Marshal(mappedPayload) // 使用映射后的payload
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -145,20 +152,16 @@ func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelK
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 先注入模型配置 head_msg(静态头部)
|
||||
// 先注入模型配置 head_msg(静态头部,适合公共模型固定 API Key)
|
||||
for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) {
|
||||
req.Header.Set(hk, hv)
|
||||
}
|
||||
// 透传必要头部(如 Authorization / X-User-Info)
|
||||
for k, v := range forwardHeaders(ctx) {
|
||||
if v != "" {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
}
|
||||
// 最后注入动态 modelKey(覆盖/补充静态 head_msg)
|
||||
|
||||
// 最后注入动态 modelKey(允许覆盖/补充静态 head_msg),适合按请求动态传密钥。
|
||||
for hk, hv := range parseHeadMsgHeaders(modelKey) {
|
||||
req.Header.Set(hk, hv)
|
||||
}
|
||||
|
||||
if method != http.MethodGet {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
@@ -174,12 +177,241 @@ func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelK
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
// 尽量把错误体带回去,方便排查
|
||||
msg := string(b)
|
||||
if len(msg) > 2000 {
|
||||
msg = msg[:2000]
|
||||
}
|
||||
return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
|
||||
}
|
||||
return b, nil
|
||||
|
||||
// ============ 新增:响应参数映射 ============
|
||||
mappedResponse, err := mapResponsePayload(m.ResponseMapping, b)
|
||||
if err != nil {
|
||||
// 响应映射失败不阻塞,返回原始数据
|
||||
g.Log().Warningf(ctx, "响应参数映射失败: %v,返回原始数据", err)
|
||||
return b, nil
|
||||
}
|
||||
// =========================================
|
||||
|
||||
return mappedResponse, nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 映射相关函数
|
||||
// ============================================
|
||||
|
||||
// mapRequestPayload 将标准请求映射为模型特定格式
|
||||
func mapRequestPayload(mappingAny any, payload any) (any, error) {
|
||||
// 1. 解析请求映射配置(值是any类型,支持bool、number等)
|
||||
mapping, err := parseRequestMapping(mappingAny)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 如果没有映射配置,直接返回原始payload
|
||||
if len(mapping) == 0 {
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
// 2. 将payload转为map
|
||||
var payloadMap map[string]any
|
||||
switch v := payload.(type) {
|
||||
case map[string]any:
|
||||
payloadMap = v
|
||||
case []map[string]any:
|
||||
// 如果传进来的是纯messages数组,包装成标准格式
|
||||
payloadMap = map[string]any{
|
||||
"messages": v,
|
||||
}
|
||||
default:
|
||||
// 通过JSON转换
|
||||
jsonBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化payload失败: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(jsonBytes, &payloadMap); err != nil {
|
||||
return nil, fmt.Errorf("反序列化payload失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 用数据库固定参数覆盖/补充
|
||||
for key, value := range mapping {
|
||||
if existingValue, exists := payloadMap[key]; !exists || isEmptyValue(existingValue) {
|
||||
payloadMap[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
return payloadMap, nil
|
||||
}
|
||||
|
||||
// mapResponsePayload 将模型响应映射为标准格式
|
||||
func mapResponsePayload(mappingAny any, responseBytes []byte) ([]byte, error) {
|
||||
mapping, err := parseResponseMapping(mappingAny)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(mapping) == 0 {
|
||||
return responseBytes, nil
|
||||
}
|
||||
|
||||
responseStr := string(responseBytes)
|
||||
resultStr := `{}`
|
||||
|
||||
for standardField, modelPath := range mapping {
|
||||
value := gjson.Get(responseStr, modelPath)
|
||||
if !value.Exists() {
|
||||
continue
|
||||
}
|
||||
|
||||
resultStr, err = sjson.SetRaw(resultStr, standardField, value.Raw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("提取字段 %s <- %s 失败: %w", standardField, modelPath, err)
|
||||
}
|
||||
}
|
||||
|
||||
return []byte(resultStr), nil
|
||||
}
|
||||
|
||||
func parseRequestMapping(mappingAny any) (map[string]any, error) {
|
||||
if mappingAny == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
result := make(map[string]any)
|
||||
|
||||
switch v := mappingAny.(type) {
|
||||
case *gvar.Var:
|
||||
if v == nil || v.IsNil() || v.IsEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
// 尝试转成 map
|
||||
if m := v.Map(); m != nil {
|
||||
for k, val := range m {
|
||||
result[k] = val
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
// 尝试转成 string
|
||||
if s := v.String(); s != "" && s != "{}" && s != "null" {
|
||||
if err := json.Unmarshal([]byte(s), &result); err != nil {
|
||||
return nil, fmt.Errorf("解析请求映射字符串失败: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
return nil, nil
|
||||
// =======================================================
|
||||
|
||||
case map[string]interface{}:
|
||||
result = v
|
||||
|
||||
case string:
|
||||
if v == "" || v == "{}" || v == "null" {
|
||||
return nil, nil
|
||||
}
|
||||
if err := json.Unmarshal([]byte(v), &result); err != nil {
|
||||
return nil, fmt.Errorf("解析请求映射字符串失败: %w", err)
|
||||
}
|
||||
|
||||
case []byte:
|
||||
if len(v) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if err := json.Unmarshal(v, &result); err != nil {
|
||||
return nil, fmt.Errorf("解析请求映射字节失败: %w", err)
|
||||
}
|
||||
|
||||
default:
|
||||
jsonBytes, err := json.Marshal(mappingAny)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化映射配置失败: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(jsonBytes, &result); err != nil {
|
||||
return nil, fmt.Errorf("解析映射配置失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// parseResponseMapping 解析响应映射配置
|
||||
// 返回值类型为 map[string]string,值都是JSON路径字符串
|
||||
func parseResponseMapping(mappingAny any) (map[string]string, error) {
|
||||
if mappingAny == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
mapping := make(map[string]string)
|
||||
|
||||
switch v := mappingAny.(type) {
|
||||
case *gvar.Var:
|
||||
if v == nil || v.IsNil() || v.IsEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
if m := v.Map(); m != nil {
|
||||
for k, val := range m {
|
||||
if strVal, ok := val.(string); ok {
|
||||
mapping[k] = strVal
|
||||
}
|
||||
}
|
||||
return mapping, nil
|
||||
}
|
||||
if s := v.String(); s != "" && s != "{}" && s != "null" {
|
||||
if err := json.Unmarshal([]byte(s), &mapping); err != nil {
|
||||
return nil, fmt.Errorf("解析响应映射字符串失败: %w", err)
|
||||
}
|
||||
return mapping, nil
|
||||
}
|
||||
return nil, nil
|
||||
case string:
|
||||
if v == "" || v == "{}" || v == "null" {
|
||||
return nil, nil
|
||||
}
|
||||
if err := json.Unmarshal([]byte(v), &mapping); err != nil {
|
||||
return nil, fmt.Errorf("解析响应映射字符串失败: %w", err)
|
||||
}
|
||||
|
||||
case map[string]interface{}:
|
||||
// 数据库JSONB直接返回的map
|
||||
for k, val := range v {
|
||||
if strVal, ok := val.(string); ok {
|
||||
mapping[k] = strVal
|
||||
}
|
||||
}
|
||||
|
||||
case []byte:
|
||||
if len(v) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if err := json.Unmarshal(v, &mapping); err != nil {
|
||||
return nil, fmt.Errorf("解析响应映射字节失败: %w", err)
|
||||
}
|
||||
|
||||
default:
|
||||
jsonBytes, err := json.Marshal(mappingAny)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化响应映射配置失败: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(jsonBytes, &mapping); err != nil {
|
||||
return nil, fmt.Errorf("解析响应映射配置失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return mapping, nil
|
||||
}
|
||||
|
||||
// isEmptyValue 判断值是否为空
|
||||
func isEmptyValue(v any) bool {
|
||||
if v == nil {
|
||||
return true
|
||||
}
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
return val == ""
|
||||
case []any:
|
||||
return len(val) == 0
|
||||
case map[string]any:
|
||||
return len(val) == 0
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user