Files
model-gateway/common/util/mapping.go

290 lines
7.4 KiB
Go
Raw Normal View History

package util
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"model-gateway/model/entity"
"net/http"
"net/url"
"regexp"
"strings"
"time"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
tgjson "github.com/tidwall/gjson"
)
// ValidatePromptResult checks that a model response carries a structurally
// complete JSON payload.
//
// raw[model.ResponseBody] must be a non-empty string that decodes to a
// non-empty JSON array of objects ("rounds"). A key of model.RequestMapping is
// treated as a required path only when its default value is empty
// (g.IsEmpty); every round must hold a non-nil value at each required path.
// Paths that have a non-empty default are not checked.
//
// It returns nil when every round passes, or an error describing the first
// problem found. A nil model is reported as an error instead of panicking.
func ValidatePromptResult(raw map[string]any, model *entity.AsynchModel) error {
	if model == nil {
		return fmt.Errorf("模型配置为空")
	}
	// 1) Read the configured response field; it must be a non-empty string.
	contentStr, ok := raw[model.ResponseBody].(string)
	if !ok || contentStr == "" {
		return fmt.Errorf("%s 字段为空或不是字符串", model.ResponseBody)
	}
	// 2) Decode it as a JSON array of objects.
	var rounds []map[string]any
	if err := gjson.DecodeTo(contentStr, &rounds); err != nil {
		return fmt.Errorf("解析 content JSON 数组失败: %w", err)
	}
	if len(rounds) == 0 {
		return fmt.Errorf("content 数组为空")
	}
	// 3) Work out the required paths once; they do not depend on the round.
	var required []string
	for path, defaultValue := range model.RequestMapping {
		if g.IsEmpty(defaultValue) {
			required = append(required, path)
		}
	}
	if len(required) == 0 {
		return nil
	}
	// 4) Check every round, wrapping it with gjson once rather than once per path.
	for i, round := range rounds {
		doc := gjson.New(round)
		for _, path := range required {
			if doc.Get(path).IsNil() {
				return fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, path)
			}
		}
	}
	return nil
}
// ReverseMap builds a new map laid out by mapping and filled from payload.
//
// Each key of mapping is a gjson path. If payload holds a non-nil value at
// that path, the value is written to the same path in the result. Otherwise
// the mapping's own value is used as a default, unless it is empty
// (g.IsEmpty), in which case the path is left out of the result.
func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
	type entry struct {
		path  string
		value any
	}
	// Wrap payload once instead of once per mapping path.
	src := gjson.New(payload)
	// Resolve every value before writing anything. Values read from src may
	// reference its nested maps, and writing them into dst first could then
	// change what later lookups in src return.
	entries := make([]entry, 0, len(mapping))
	for path, defaultValue := range mapping {
		if val := src.Get(path); !val.IsNil() {
			entries = append(entries, entry{path: path, value: val.Val()})
		} else if !g.IsEmpty(defaultValue) {
			entries = append(entries, entry{path: path, value: defaultValue})
		}
	}
	dst := gjson.New("{}")
	for _, e := range entries {
		// Set errors are ignored, as in the original: a path that cannot be
		// written is simply skipped (best effort).
		_ = dst.Set(e.path, e.value)
	}
	return dst.Map()
}
// MapResponsePayload maps a model response onto the standard field layout.
//
// mapping maps each standard field name to a tidwall/gjson path (converted
// with gconv.String) that is evaluated against result. Fields whose path is
// empty or does not resolve are omitted. For paths containing '#', the
// matched values are collected into a []any; any other path yields the single
// matched value.
//
// When mapping is empty, result is returned unchanged. An error is returned
// if result cannot be encoded as JSON.
func MapResponsePayload(mapping map[string]any, result map[string]any) (map[string]any, error) {
	if len(mapping) == 0 {
		return result, nil
	}
	// tidwall/gjson queries raw JSON text, so encode result once up front.
	// A failed encode used to be ignored and silently produced an empty mapping.
	resultBytes, err := json.Marshal(result)
	if err != nil {
		return nil, fmt.Errorf("序列化模型响应失败: %w", err)
	}
	resultStr := string(resultBytes)

	mapped := make(map[string]any, len(mapping))
	for standardField, modelPath := range mapping {
		path := gconv.String(modelPath)
		if path == "" {
			continue
		}
		value := tgjson.Get(resultStr, path)
		if !value.Exists() {
			continue
		}
		if !strings.Contains(path, "#") {
			mapped[standardField] = value.Value()
			continue
		}
		// Array path: collect each matched element's Go value.
		var arr []any
		for _, item := range value.Array() {
			arr = append(arr, item.Value())
		}
		mapped[standardField] = arr
	}
	return mapped, nil
}
// ParseHeadMsgHeaders converts a decoded head_msg object into HTTP request
// headers, turning every value into a string with gconv.String.
// It returns nil when headMsg is nil or empty.
//
// Example head_msg:
//
//	{
//	  "Authorization": "Bearer xxx",
//	  "Content-Type": "application/json",
//	  "X-Api-App-Id": "5147401364",
//	  "X-Api-Access-Key": "VCqRX7..."
//	}
func ParseHeadMsgHeaders(headMsg map[string]any) map[string]string {
	var headers map[string]string
	if n := len(headMsg); n > 0 {
		headers = make(map[string]string, n)
		for name, value := range headMsg {
			headers[name] = gconv.String(value)
		}
	}
	return headers
}
// GetModelBody returns the model payload held in a stored model record.
// If v has a "body" key, that value is converted with gconv.Map and returned;
// otherwise v itself is returned. A nil v yields nil.
func GetModelBody(v map[string]any) map[string]any {
	body, found := v["body"]
	switch {
	case v == nil:
		return nil
	case found:
		return gconv.Map(body)
	default:
		return v
	}
}
// BodyToQuery converts a request body into url.Values. Each non-nil value is
// stored under its key as gconv.String(value); nil values are skipped. The
// error result is always nil.
func BodyToQuery(payload map[string]any) (url.Values, error) {
	values := make(url.Values, len(payload))
	for key, raw := range payload {
		if raw != nil {
			values.Set(key, gconv.String(raw))
		}
	}
	return values, nil
}
// PullTaskResult polls a task-status endpoint until the task finishes or ctx
// is done.
//
// queryConfig keys read:
//   - "url": endpoint URL; "{id}" placeholders are replaced with taskID
//   - "method": HTTP method, case-insensitive, default "GET"
//   - "headers": map[string]any of request headers (values via gconv.String)
//   - "interval_seconds": delay between polls; values <= 0 fall back to 2
//   - "body": map[string]any merged with {"id": taskID} and sent as JSON for POST
//
// Outcome by the response's "status" field:
//   - "succeeded": the decoded response is returned
//   - "failed" / "expired": the decoded response is returned with an error
//   - "queued" / "running": wait one interval and poll again
//   - anything else, including a missing status: the decoded response is returned
//
// Transport errors, non-2xx responses and undecodable bodies are logged and
// retried. There is no attempt limit, so callers should bound polling with a
// ctx deadline or cancellation; ctx is also honoured while waiting between polls.
func PullTaskResult(ctx context.Context, taskID string, queryConfig map[string]any) (map[string]any, error) {
	// 1. Read the configuration.
	rawURL := gconv.String(queryConfig["url"])
	method := strings.ToUpper(gconv.String(queryConfig["method"]))
	if method == "" {
		method = http.MethodGet
	}
	headers, _ := queryConfig["headers"].(map[string]any)
	interval := time.Duration(gconv.Int(queryConfig["interval_seconds"])) * time.Second
	if interval <= 0 {
		interval = 2 * time.Second
	}

	// 2. Substitute the task ID into the URL.
	params := map[string]any{"id": taskID}
	finalURL := replaceURLParams(rawURL, params)

	// 3. The request body does not change between polls, so encode it once.
	var bodyBytes []byte
	if method == http.MethodPost {
		bodyCfg, _ := queryConfig["body"].(map[string]any)
		bs, err := json.Marshal(buildParams(bodyCfg, params))
		if err != nil {
			return nil, fmt.Errorf("序列化请求体失败: %w", err)
		}
		bodyBytes = bs
	}

	// 4. Poll, reusing one client for every attempt.
	client := &http.Client{Timeout: 30 * time.Second}
	for {
		if err := ctx.Err(); err != nil {
			return nil, err
		}
		result, ok, err := pullTaskResultOnce(ctx, client, method, finalURL, headers, bodyBytes, taskID)
		if err != nil {
			return nil, err
		}
		if ok {
			status := gconv.String(result["status"])
			g.Log().Infof(ctx, "[PullTaskResult] 轮询 taskID=%s status=%s", taskID, status)
			switch status {
			case "succeeded":
				return result, nil
			case "failed", "expired":
				return result, fmt.Errorf("任务失败: status=%s", status)
			case "queued", "running":
				// Still in progress: wait below, then poll again.
			default:
				// Responses without a recognised status field are returned as-is.
				return result, nil
			}
		}
		if err := pullTaskResultWait(ctx, interval); err != nil {
			return nil, err
		}
	}
}

// pullTaskResultOnce performs a single poll request. It returns the decoded
// response with ok=true, or ok=false after logging a retryable failure
// (transport error, non-2xx status, undecodable body). A non-nil error is
// returned only when the request itself cannot be built.
func pullTaskResultOnce(ctx context.Context, client *http.Client, method, target string, headers map[string]any, body []byte, taskID string) (map[string]any, bool, error) {
	var reqBody io.Reader
	if body != nil {
		reqBody = bytes.NewReader(body)
	}
	req, err := http.NewRequestWithContext(ctx, method, target, reqBody)
	if err != nil {
		return nil, false, fmt.Errorf("创建请求失败: %w", err)
	}
	for k, v := range headers {
		req.Header.Set(k, gconv.String(v))
	}
	if reqBody != nil && req.Header.Get("Content-Type") == "" {
		req.Header.Set("Content-Type", "application/json")
	}

	resp, err := client.Do(req)
	if err != nil {
		g.Log().Warningf(ctx, "[PullTaskResult] 请求失败 taskID=%s err=%v", taskID, err)
		return nil, false, nil
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		// Cap how much of an error body is buffered and logged.
		all, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
		g.Log().Warningf(ctx, "[PullTaskResult] 请求异常 taskID=%s status=%d body=%s", taskID, resp.StatusCode, string(all))
		return nil, false, nil
	}

	var result map[string]any
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		g.Log().Warningf(ctx, "[PullTaskResult] 解析失败 taskID=%s err=%v", taskID, err)
		return nil, false, nil
	}
	return result, true, nil
}

// pullTaskResultWait blocks for d or until ctx is done, whichever comes
// first, returning ctx.Err() in the latter case.
func pullTaskResultWait(ctx context.Context, d time.Duration) error {
	t := time.NewTimer(d)
	defer t.Stop()
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-t.C:
		return nil
	}
}
// buildParams returns a fresh map that merges params over bodyCfg. Every key
// from bodyCfg is copied first, then every key from params is written, so
// params wins on conflicts. Neither input is modified, nil inputs are treated
// as empty, and the result is never nil.
func buildParams(bodyCfg map[string]any, params map[string]any) map[string]any {
	merged := make(map[string]any, len(bodyCfg)+len(params))
	for _, layer := range [...]map[string]any{bodyCfg, params} {
		for key, val := range layer {
			merged[key] = val
		}
	}
	return merged
}
// replaceURLParams 替换 URL 中的 {key}
func replaceURLParams(url string, params map[string]any) string {
re := regexp.MustCompile(`\{([^}]+)\}`)
return re.ReplaceAllStringFunc(url, func(s string) string {
key := strings.Trim(s, "{}")
if val, ok := params[key]; ok {
return gconv.String(val)
}
return s
})
}
// replaceBodyParams returns a new map holding every entry of bodyCfg, with
// entries from params overriding any key present in both. Neither input is
// modified, nil inputs are treated as empty, and the result is never nil.
// It applies the same merge rule as buildParams.
func replaceBodyParams(bodyCfg map[string]any, params map[string]any) map[string]any {
	out := map[string]any{}
	for key := range bodyCfg {
		out[key] = bodyCfg[key]
	}
	for key := range params {
		out[key] = params[key]
	}
	return out
}
// InjectCallbackURL writes the gateway callback address into the request body.
//
// The value written is GetCallbackURL(ctx, "/task/modelCallback"), stored in
// payload under the key callbackURL. When callbackURL is empty, payload is
// returned untouched. payload is modified in place and also returned; a nil
// payload is replaced by a new map so the write cannot panic.
//
// NOTE(review): despite its name, callbackURL is used as the map KEY (the name
// of the body field), not as the address itself. Confirm this against callers.
func InjectCallbackURL(ctx context.Context, payload map[string]any, callbackURL string) map[string]any {
	if callbackURL == "" {
		return payload
	}
	if payload == nil {
		payload = make(map[string]any, 1)
	}
	payload[callbackURL] = GetCallbackURL(ctx, "/task/modelCallback")
	return payload
}