290 lines
7.4 KiB
Go
290 lines
7.4 KiB
Go
package util
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"model-gateway/model/entity"
|
||
"net/http"
|
||
"net/url"
|
||
"regexp"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/gogf/gf/v2/encoding/gjson"
|
||
"github.com/gogf/gf/v2/frame/g"
|
||
"github.com/gogf/gf/v2/util/gconv"
|
||
tgjson "github.com/tidwall/gjson"
|
||
)
|
||
|
||
// ValidatePromptResult 校验模型返回结果的 JSON 结构完整性
|
||
// 校验逻辑:只校验 requestMapping 中默认值为空的必填字段
|
||
func ValidatePromptResult(raw map[string]any, model *entity.AsynchModel) error {
|
||
// 1) 获取校验配置,并取值
|
||
requestMapping := model.RequestMapping
|
||
contentStr, ok := raw[model.ResponseBody].(string)
|
||
if !ok || contentStr == "" {
|
||
return fmt.Errorf("%s 字段为空或不是字符串", model.ResponseBody)
|
||
}
|
||
|
||
// 2) 解析 content 为 JSON 数组
|
||
var rounds []map[string]any
|
||
if err := gjson.DecodeTo(contentStr, &rounds); err != nil {
|
||
return fmt.Errorf("解析 content JSON 数组失败: %w", err)
|
||
}
|
||
if len(rounds) == 0 {
|
||
return fmt.Errorf("content 数组为空")
|
||
}
|
||
|
||
// 3) 逐条校验:只检查默认值为空的必填字段是否存在
|
||
for i, round := range rounds {
|
||
for path, defaultValue := range requestMapping {
|
||
if !g.IsEmpty(defaultValue) {
|
||
continue
|
||
}
|
||
if gjson.New(round).Get(path).IsNil() {
|
||
return fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, path)
|
||
}
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// ReverseMap 映射 payload 到 mapping
|
||
func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
|
||
jsonObj := gjson.New("{}")
|
||
for path, defaultValue := range mapping {
|
||
// 从 payload 取对应路径的值
|
||
val := gjson.New(payload).Get(path)
|
||
if !val.IsNil() {
|
||
// payload 有值,用它
|
||
_ = jsonObj.Set(path, val.Val())
|
||
} else if !g.IsEmpty(defaultValue) {
|
||
// payload 没值,用默认值
|
||
_ = jsonObj.Set(path, defaultValue)
|
||
}
|
||
}
|
||
return jsonObj.Map()
|
||
}
|
||
|
||
// MapResponsePayload 映射模型响应为标准格式
|
||
func MapResponsePayload(mapping map[string]any, result map[string]any) (map[string]any, error) {
|
||
if len(mapping) == 0 {
|
||
return result, nil
|
||
}
|
||
|
||
// 把 result 转成 JSON 字符串,tidwall/gjson 需要字符串输入
|
||
resultBytes, _ := json.Marshal(result)
|
||
resultStr := string(resultBytes)
|
||
|
||
mapped := make(map[string]any)
|
||
|
||
for standardField, modelPath := range mapping {
|
||
path := gconv.String(modelPath)
|
||
if path == "" {
|
||
continue
|
||
}
|
||
|
||
value := tgjson.Get(resultStr, path)
|
||
if !value.Exists() {
|
||
continue
|
||
}
|
||
// 如果是数组路径(含 #),取 Array;否则取单值
|
||
if strings.Contains(path, "#") {
|
||
var arr []any
|
||
for _, v := range value.Array() {
|
||
arr = append(arr, v.Value())
|
||
}
|
||
mapped[standardField] = arr
|
||
} else {
|
||
mapped[standardField] = value.Value()
|
||
}
|
||
}
|
||
|
||
return mapped, nil
|
||
}
|
||
|
||
// ParseHeadMsgHeaders 从 head_msg JSON 中提取请求头
|
||
// head_msg 格式示例:
|
||
//
|
||
// {
|
||
// "Authorization": "Bearer xxx",
|
||
// "Content-Type": "application/json",
|
||
// "X-Api-App-Id": "5147401364",
|
||
// "X-Api-Access-Key": "VCqRX7..."
|
||
// }
|
||
func ParseHeadMsgHeaders(headMsg map[string]any) map[string]string {
|
||
if len(headMsg) == 0 {
|
||
return nil
|
||
}
|
||
out := make(map[string]string, len(headMsg))
|
||
for k, v := range headMsg {
|
||
out[k] = gconv.String(v)
|
||
}
|
||
return out
|
||
}
|
||
|
||
// GetModelBody 获取数据库中保存的模型信息
|
||
func GetModelBody(v map[string]any) map[string]any {
|
||
if v == nil {
|
||
return nil
|
||
}
|
||
if p, ok := v["body"]; ok {
|
||
return gconv.Map(p)
|
||
}
|
||
return v
|
||
}
|
||
|
||
// BodyToQuery 将 body 转为 url.Values
|
||
func BodyToQuery(payload map[string]any) (url.Values, error) {
|
||
q := url.Values{}
|
||
for k, v := range payload {
|
||
if v == nil {
|
||
continue
|
||
}
|
||
q.Set(k, gconv.String(v))
|
||
}
|
||
return q, nil
|
||
}
|
||
|
||
// PullTaskResult 轮询查询任务结果直到完成
|
||
func PullTaskResult(ctx context.Context, taskID string, queryConfig map[string]any) (map[string]any, error) {
|
||
// 1. 解析配置
|
||
url := gconv.String(queryConfig["url"])
|
||
method := gconv.String(queryConfig["method"])
|
||
headers, _ := queryConfig["headers"].(map[string]any)
|
||
interval := gconv.Int(queryConfig["interval_seconds"])
|
||
if interval <= 0 {
|
||
interval = 2
|
||
}
|
||
|
||
if method == "" {
|
||
method = "GET"
|
||
}
|
||
|
||
// 2. 构建参数
|
||
params := map[string]any{"id": taskID}
|
||
|
||
// 3. 替换 URL 中的 {id}
|
||
finalURL := replaceURLParams(url, params)
|
||
|
||
// 4. 构建请求体
|
||
bodyCfg, _ := queryConfig["body"].(map[string]any)
|
||
body := buildParams(bodyCfg, params)
|
||
|
||
// 5. 轮询
|
||
for {
|
||
select {
|
||
case <-ctx.Done():
|
||
return nil, ctx.Err()
|
||
default:
|
||
}
|
||
|
||
var reqBody io.Reader
|
||
if method == "POST" && body != nil {
|
||
bs, _ := json.Marshal(body)
|
||
reqBody = bytes.NewReader(bs)
|
||
}
|
||
|
||
req, err := http.NewRequestWithContext(ctx, method, finalURL, reqBody)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||
}
|
||
|
||
for k, v := range headers {
|
||
req.Header.Set(k, gconv.String(v))
|
||
}
|
||
if req.Header.Get("Content-Type") == "" && reqBody != nil {
|
||
req.Header.Set("Content-Type", "application/json")
|
||
}
|
||
|
||
client := &http.Client{Timeout: 30 * time.Second}
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
g.Log().Warningf(ctx, "[PullTaskResult] 请求失败 taskID=%s err=%v", taskID, err)
|
||
time.Sleep(time.Duration(interval) * time.Second)
|
||
continue
|
||
}
|
||
|
||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||
all, _ := io.ReadAll(resp.Body)
|
||
resp.Body.Close()
|
||
g.Log().Warningf(ctx, "[PullTaskResult] 请求异常 taskID=%s status=%d body=%s", taskID, resp.StatusCode, string(all))
|
||
time.Sleep(time.Duration(interval) * time.Second)
|
||
continue
|
||
}
|
||
|
||
var result map[string]any
|
||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||
resp.Body.Close()
|
||
g.Log().Warningf(ctx, "[PullTaskResult] 解析失败 taskID=%s err=%v", taskID, err)
|
||
time.Sleep(time.Duration(interval) * time.Second)
|
||
continue
|
||
}
|
||
resp.Body.Close()
|
||
|
||
status := gconv.String(result["status"])
|
||
g.Log().Infof(ctx, "[PullTaskResult] 轮询 taskID=%s status=%s", taskID, status)
|
||
|
||
switch status {
|
||
case "succeeded":
|
||
return result, nil
|
||
case "failed", "expired":
|
||
return result, fmt.Errorf("任务失败: status=%s", status)
|
||
case "queued", "running":
|
||
time.Sleep(time.Duration(interval) * time.Second)
|
||
continue
|
||
default:
|
||
// 兼容没有 status 字段的情况,直接返回
|
||
return result, nil
|
||
}
|
||
}
|
||
}
|
||
|
||
// buildParams merges params over bodyCfg into a newly allocated map.
//
// Neither input is modified. A key present in both takes the value from
// params. Nil inputs are treated as empty, and the result is never nil.
func buildParams(bodyCfg map[string]any, params map[string]any) map[string]any {
	merged := make(map[string]any, len(bodyCfg)+len(params))
	// Later sources win, so params must come after bodyCfg.
	for _, src := range [...]map[string]any{bodyCfg, params} {
		for key, val := range src {
			merged[key] = val
		}
	}
	return merged
}
|
||
|
||
// replaceURLParams 替换 URL 中的 {key}
|
||
func replaceURLParams(url string, params map[string]any) string {
|
||
re := regexp.MustCompile(`\{([^}]+)\}`)
|
||
return re.ReplaceAllStringFunc(url, func(s string) string {
|
||
key := strings.Trim(s, "{}")
|
||
if val, ok := params[key]; ok {
|
||
return gconv.String(val)
|
||
}
|
||
return s
|
||
})
|
||
}
|
||
|
||
// replaceBodyParams returns a new map containing bodyCfg with params laid
// over it. A key present in both takes the value from params.
//
// Neither input is modified. Nil inputs are treated as empty, and the result
// is never nil. The result is pre-sized for both inputs, as buildParams does
// (this function has the same behavior).
func replaceBodyParams(bodyCfg map[string]any, params map[string]any) map[string]any {
	result := make(map[string]any, len(bodyCfg)+len(params))
	for k, v := range bodyCfg {
		result[k] = v
	}
	for k, v := range params {
		result[k] = v
	}
	return result
}
|
||
|
||
// InjectCallbackURL 将回调地址注入到请求体中
|
||
func InjectCallbackURL(ctx context.Context, payload map[string]any, callbackURL string) map[string]any {
|
||
if callbackURL == "" {
|
||
return payload
|
||
}
|
||
payload[callbackURL] = GetCallbackURL(ctx, "/task/modelCallback")
|
||
return payload
|
||
}
|