Files
model-gateway/common/util/mapping.go

290 lines
7.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package util
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"model-gateway/model/entity"
"net/http"
"net/url"
"regexp"
"strings"
"time"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
tgjson "github.com/tidwall/gjson"
)
// ValidatePromptResult checks that a model response contains a JSON array whose
// elements all carry the required fields.
//
// raw[model.ResponseBody] must be a non-empty string holding a JSON array of
// objects. A key of model.RequestMapping is treated as a required field when its
// default value is empty (per g.IsEmpty); keys with a non-empty default are
// skipped. Each required key is looked up as a gjson path in every array element.
//
// It returns nil when every element has every required field. Otherwise it
// returns an error describing a problem: nil model, missing or non-string
// content, undecodable JSON, an empty array, or a missing field. When several
// fields are missing, which one is reported is unspecified because
// RequestMapping is a map with random iteration order.
func ValidatePromptResult(raw map[string]any, model *entity.AsynchModel) error {
	if model == nil {
		return fmt.Errorf("模型配置为空")
	}

	// 1) Fetch the response content string.
	contentStr, ok := raw[model.ResponseBody].(string)
	if !ok || contentStr == "" {
		return fmt.Errorf("%s 字段为空或不是字符串", model.ResponseBody)
	}

	// 2) Decode the content as a JSON array of objects.
	var rounds []map[string]any
	if err := gjson.DecodeTo(contentStr, &rounds); err != nil {
		return fmt.Errorf("解析 content JSON 数组失败: %w", err)
	}
	if len(rounds) == 0 {
		return fmt.Errorf("content 数组为空")
	}

	// 3) Collect the required paths once (empty default => required) instead of
	// re-filtering the mapping for every element.
	required := make([]string, 0, len(model.RequestMapping))
	for path, defaultValue := range model.RequestMapping {
		if g.IsEmpty(defaultValue) {
			required = append(required, path)
		}
	}

	// 4) Check every element. Wrap each element once, not once per path.
	for i, round := range rounds {
		doc := gjson.New(round)
		for _, path := range required {
			if doc.Get(path).IsNil() {
				return fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, path)
			}
		}
	}
	return nil
}
// ReverseMap builds a nested map shaped by mapping, filling it from payload.
//
// For each (path, defaultValue) pair in mapping, where path uses gf gjson path
// syntax:
//   - if payload has a non-nil value at path, that value is written at path;
//   - otherwise, if defaultValue is non-empty (per g.IsEmpty), it is written;
//   - otherwise the path is omitted from the result.
//
// Errors from Set are ignored; this is best-effort. NOTE(review): mapping is
// iterated in random order, so overlapping paths (e.g. "a" and "a.b") may
// interact order-dependently.
func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
	// Wrap payload once. The original re-wrapped (and re-converted) it on every
	// iteration although lookups never change it.
	src := gjson.New(payload)
	out := gjson.New("{}")
	for path, defaultValue := range mapping {
		if val := src.Get(path); !val.IsNil() {
			// payload has a value: use it.
			_ = out.Set(path, val.Val())
			continue
		}
		if !g.IsEmpty(defaultValue) {
			// payload has no value: fall back to the configured default.
			_ = out.Set(path, defaultValue)
		}
	}
	return out.Map()
}
// MapResponsePayload maps a raw model response onto the standard field names.
//
// mapping maps a standard field name to a tidwall/gjson path into result. The
// path is stringified with gconv.String, and empty paths are skipped. Fields
// whose path does not exist in result are omitted from the output.
//
// If the path contains "#", the matched value is flattened into a []any.
// Otherwise the plain value is used. An array path that matches an empty array
// yields an empty (non-nil) slice. NOTE(review): per gjson's Result.Array
// contract, a non-array match on a "#" path is wrapped in a one-element slice.
// Examples are "items.#" (a count) or a single-match "#(...)" query. Confirm
// that this is intended.
//
// When mapping is empty, result is returned unchanged. An error is returned
// only if result cannot be JSON-encoded.
func MapResponsePayload(mapping map[string]any, result map[string]any) (map[string]any, error) {
	if len(mapping) == 0 {
		return result, nil
	}

	// tidwall/gjson queries raw JSON text, so serialize the response once.
	resultBytes, err := json.Marshal(result)
	if err != nil {
		return nil, fmt.Errorf("序列化模型响应失败: %w", err)
	}
	resultStr := string(resultBytes)

	mapped := make(map[string]any, len(mapping))
	for standardField, modelPath := range mapping {
		path := gconv.String(modelPath)
		if path == "" {
			continue
		}
		value := tgjson.Get(resultStr, path)
		if !value.Exists() {
			continue
		}
		if strings.Contains(path, "#") {
			items := value.Array()
			arr := make([]any, 0, len(items))
			for _, v := range items {
				arr = append(arr, v.Value())
			}
			mapped[standardField] = arr
			continue
		}
		mapped[standardField] = value.Value()
	}
	return mapped, nil
}
// ParseHeadMsgHeaders turns a head_msg object into HTTP header name/value pairs.
//
// Example head_msg:
//
//	{
//	  "Authorization": "Bearer <token>",
//	  "Content-Type": "application/json",
//	  "X-Api-App-Id": "<app-id>",
//	  "X-Api-Access-Key": "<access-key>"
//	}
//
// Every value is stringified with gconv.String. A nil or empty headMsg yields nil.
func ParseHeadMsgHeaders(headMsg map[string]any) map[string]string {
	count := len(headMsg)
	if count == 0 {
		return nil
	}
	headers := make(map[string]string, count)
	for name, raw := range headMsg {
		headers[name] = gconv.String(raw)
	}
	return headers
}
// GetModelBody extracts the model request body from a stored model record.
//
// If v contains a "body" key, that value is converted with gconv.Map and
// returned. Otherwise v itself is returned as the body. A nil v yields nil.
func GetModelBody(v map[string]any) map[string]any {
	if v == nil {
		return nil
	}
	body, found := v["body"]
	if !found {
		return v
	}
	return gconv.Map(body)
}
// BodyToQuery converts a flat payload map into URL query values.
//
// Each non-nil value is stringified with gconv.String and stored under its key.
// Entries whose value is nil are dropped. The returned url.Values is never nil,
// and the error is currently always nil.
func BodyToQuery(payload map[string]any) (url.Values, error) {
	values := make(url.Values)
	for key, val := range payload {
		if val != nil {
			values.Set(key, gconv.String(val))
		}
	}
	return values, nil
}
// PullTaskResult polls a task's query endpoint until the task reaches a final
// status or ctx is done.
//
// queryConfig keys that are read:
//   - "url": query URL; "{id}" placeholders are replaced with taskID.
//   - "method": HTTP method, case-insensitive, default "GET".
//   - "headers": map[string]any of request headers (other types are ignored).
//   - "interval_seconds": wait between polls; values <= 0 mean 2 seconds.
//   - "body": map[string]any request body, with "id" forced to taskID. It is
//     sent only for POST requests.
//
// The decoded JSON response's "status" field drives the result:
//   - "succeeded", or a missing/unknown status: returns the response and nil.
//   - "failed" or "expired": returns the response together with an error.
//   - "queued" or "running": waits, then polls again.
//
// Transport errors, non-2xx responses, and undecodable bodies are logged and
// retried. ctx cancellation is honoured both during requests and while waiting.
// NOTE(review): without a ctx deadline, persistent failures retry forever.
func PullTaskResult(ctx context.Context, taskID string, queryConfig map[string]any) (map[string]any, error) {
	// 1. Read the configuration. rawURL avoids shadowing the net/url package.
	rawURL := gconv.String(queryConfig["url"])
	method := strings.ToUpper(strings.TrimSpace(gconv.String(queryConfig["method"])))
	if method == "" {
		method = http.MethodGet
	}
	headers, _ := queryConfig["headers"].(map[string]any)
	intervalSec := gconv.Int(queryConfig["interval_seconds"])
	if intervalSec <= 0 {
		intervalSec = 2
	}
	interval := time.Duration(intervalSec) * time.Second

	// 2. Resolve the URL and encode the request body once. Neither changes
	// between polls.
	params := map[string]any{"id": taskID}
	finalURL := replaceURLParams(rawURL, params)

	var payload []byte
	if method == http.MethodPost {
		bodyCfg, _ := queryConfig["body"].(map[string]any)
		bs, err := json.Marshal(buildParams(bodyCfg, params))
		if err != nil {
			return nil, fmt.Errorf("序列化请求体失败: %w", err)
		}
		payload = bs
	}

	// One client for all polls, so keep-alive connections are reused.
	client := &http.Client{Timeout: 30 * time.Second}

	// 3. Poll.
	for {
		if err := ctx.Err(); err != nil {
			return nil, err
		}

		var reqBody io.Reader
		if payload != nil {
			reqBody = bytes.NewReader(payload)
		}
		req, err := http.NewRequestWithContext(ctx, method, finalURL, reqBody)
		if err != nil {
			return nil, fmt.Errorf("创建请求失败: %w", err)
		}
		for k, v := range headers {
			req.Header.Set(k, gconv.String(v))
		}
		if reqBody != nil && req.Header.Get("Content-Type") == "" {
			req.Header.Set("Content-Type", "application/json")
		}

		result, err := pullTaskOnce(ctx, client, req, taskID)
		if err == nil {
			status := gconv.String(result["status"])
			g.Log().Infof(ctx, "[PullTaskResult] 轮询 taskID=%s status=%s", taskID, status)
			switch status {
			case "succeeded":
				return result, nil
			case "failed", "expired":
				return result, fmt.Errorf("任务失败: status=%s", status)
			case "queued", "running":
				// Still in progress; wait below and poll again.
			default:
				// No status field (or an unknown one): treat the response as final.
				return result, nil
			}
		}

		if err := pullTaskWait(ctx, interval); err != nil {
			return nil, err
		}
	}
}

// pullTaskOnce sends one poll request and decodes its JSON body. Failures are
// logged here and returned so the caller can retry. The response body is always
// drained (bounded) and closed so the connection can be reused.
func pullTaskOnce(ctx context.Context, client *http.Client, req *http.Request, taskID string) (map[string]any, error) {
	resp, err := client.Do(req)
	if err != nil {
		g.Log().Warningf(ctx, "[PullTaskResult] 请求失败 taskID=%s err=%v", taskID, err)
		return nil, err
	}
	defer func() {
		_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 64<<10))
		_ = resp.Body.Close()
	}()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		all, _ := io.ReadAll(resp.Body)
		g.Log().Warningf(ctx, "[PullTaskResult] 请求异常 taskID=%s status=%d body=%s", taskID, resp.StatusCode, string(all))
		return nil, fmt.Errorf("unexpected status %d", resp.StatusCode)
	}

	var result map[string]any
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		g.Log().Warningf(ctx, "[PullTaskResult] 解析失败 taskID=%s err=%v", taskID, err)
		return nil, err
	}
	return result, nil
}

// pullTaskWait blocks for d, or until ctx is done, in which case it returns ctx.Err().
func pullTaskWait(ctx context.Context, d time.Duration) error {
	timer := time.NewTimer(d)
	defer timer.Stop()
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-timer.C:
		return nil
	}
}
// buildParams merges bodyCfg and params into a fresh map. When a key exists in
// both, the value from params wins. Neither input is modified, and the result
// is always non-nil, even when both inputs are nil.
func buildParams(bodyCfg map[string]any, params map[string]any) map[string]any {
	merged := make(map[string]any, len(bodyCfg)+len(params))
	// Later sources overwrite earlier ones, so params is applied last.
	for _, src := range [...]map[string]any{bodyCfg, params} {
		for key, val := range src {
			merged[key] = val
		}
	}
	return merged
}
// urlPlaceholderRe matches "{key}" placeholders in a URL template. It is
// compiled once at package init instead of on every replaceURLParams call.
var urlPlaceholderRe = regexp.MustCompile(`\{([^}]+)\}`)

// replaceURLParams replaces each "{key}" placeholder in rawURL with
// gconv.String(params[key]). Placeholders whose key is absent from params are
// left as-is. The key is obtained by trimming all leading '{' and trailing '}'
// characters from the match.
func replaceURLParams(rawURL string, params map[string]any) string {
	return urlPlaceholderRe.ReplaceAllStringFunc(rawURL, func(match string) string {
		key := strings.Trim(match, "{}")
		if val, ok := params[key]; ok {
			return gconv.String(val)
		}
		return match
	})
}
// replaceBodyParams returns a new map containing bodyCfg's entries, overwritten
// by params' entries where keys collide. Neither input is modified, and the
// result is always non-nil.
//
// NOTE(review): this is behaviourally identical to buildParams. Consider
// removing one of them and updating callers.
func replaceBodyParams(bodyCfg map[string]any, params map[string]any) map[string]any {
	// Pre-size for the worst case (no overlapping keys) to avoid rehashing.
	result := make(map[string]any, len(bodyCfg)+len(params))
	for k, v := range bodyCfg {
		result[k] = v
	}
	for k, v := range params {
		result[k] = v
	}
	return result
}
// InjectCallbackURL stores a model-callback URL in payload.
//
// Despite its name, callbackURL is used as the payload KEY (field name). The
// value stored under it is GetCallbackURL(ctx, "/task/modelCallback").
// NOTE(review): confirm that callers pass a field name here.
//
// When callbackURL is empty, payload is returned untouched (possibly nil).
// Otherwise a non-nil payload is modified in place and returned. A nil payload
// is replaced by a new map, because writing to a nil map would panic.
func InjectCallbackURL(ctx context.Context, payload map[string]any, callbackURL string) map[string]any {
	if callbackURL == "" {
		return payload
	}
	if payload == nil {
		payload = make(map[string]any, 1)
	}
	payload[callbackURL] = GetCallbackURL(ctx, "/task/modelCallback")
	return payload
}