2026-05-29 17:54:19 +08:00
|
|
|
|
package util
|
|
|
|
|
|
|
|
|
|
|
|
import (
|
2026-06-02 20:26:45 +08:00
|
|
|
|
"bytes"
|
|
|
|
|
|
"context"
|
2026-05-30 22:08:46 +08:00
|
|
|
|
"encoding/json"
|
2026-05-29 17:54:19 +08:00
|
|
|
|
"fmt"
|
2026-06-02 20:26:45 +08:00
|
|
|
|
"io"
|
2026-05-29 17:54:19 +08:00
|
|
|
|
"model-gateway/model/entity"
|
2026-06-02 20:26:45 +08:00
|
|
|
|
"net/http"
|
2026-05-29 17:54:19 +08:00
|
|
|
|
"net/url"
|
2026-06-02 20:26:45 +08:00
|
|
|
|
"regexp"
|
2026-05-29 17:54:19 +08:00
|
|
|
|
"strings"
|
2026-06-02 20:26:45 +08:00
|
|
|
|
"time"
|
2026-05-29 17:54:19 +08:00
|
|
|
|
|
|
|
|
|
|
"github.com/gogf/gf/v2/encoding/gjson"
|
|
|
|
|
|
"github.com/gogf/gf/v2/frame/g"
|
|
|
|
|
|
"github.com/gogf/gf/v2/util/gconv"
|
2026-05-30 22:08:46 +08:00
|
|
|
|
tgjson "github.com/tidwall/gjson"
|
2026-05-29 17:54:19 +08:00
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
// ValidatePromptResult 校验模型返回结果的 JSON 结构完整性
|
|
|
|
|
|
// 校验逻辑:只校验 requestMapping 中默认值为空的必填字段
|
|
|
|
|
|
func ValidatePromptResult(raw map[string]any, model *entity.AsynchModel) error {
|
|
|
|
|
|
// 1) 获取校验配置,并取值
|
|
|
|
|
|
requestMapping := model.RequestMapping
|
2026-06-02 20:26:45 +08:00
|
|
|
|
contentStr, ok := raw[model.ResponseBody].(string)
|
2026-05-29 17:54:19 +08:00
|
|
|
|
if !ok || contentStr == "" {
|
2026-06-02 20:26:45 +08:00
|
|
|
|
return fmt.Errorf("%s 字段为空或不是字符串", model.ResponseBody)
|
2026-05-29 17:54:19 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 2) 解析 content 为 JSON 数组
|
|
|
|
|
|
var rounds []map[string]any
|
|
|
|
|
|
if err := gjson.DecodeTo(contentStr, &rounds); err != nil {
|
|
|
|
|
|
return fmt.Errorf("解析 content JSON 数组失败: %w", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
if len(rounds) == 0 {
|
|
|
|
|
|
return fmt.Errorf("content 数组为空")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 3) 逐条校验:只检查默认值为空的必填字段是否存在
|
|
|
|
|
|
for i, round := range rounds {
|
|
|
|
|
|
for path, defaultValue := range requestMapping {
|
|
|
|
|
|
if !g.IsEmpty(defaultValue) {
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
if gjson.New(round).Get(path).IsNil() {
|
|
|
|
|
|
return fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, path)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// ReverseMap 映射 payload 到 mapping
|
|
|
|
|
|
func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
|
|
|
|
|
|
jsonObj := gjson.New("{}")
|
|
|
|
|
|
for path, defaultValue := range mapping {
|
|
|
|
|
|
// 从 payload 取对应路径的值
|
|
|
|
|
|
val := gjson.New(payload).Get(path)
|
|
|
|
|
|
if !val.IsNil() {
|
|
|
|
|
|
// payload 有值,用它
|
|
|
|
|
|
_ = jsonObj.Set(path, val.Val())
|
|
|
|
|
|
} else if !g.IsEmpty(defaultValue) {
|
|
|
|
|
|
// payload 没值,用默认值
|
|
|
|
|
|
_ = jsonObj.Set(path, defaultValue)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
return jsonObj.Map()
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// MapResponsePayload 映射模型响应为标准格式
|
2026-05-30 22:08:46 +08:00
|
|
|
|
func MapResponsePayload(mapping map[string]any, result map[string]any) (map[string]any, error) {
|
2026-05-29 17:54:19 +08:00
|
|
|
|
if len(mapping) == 0 {
|
2026-05-30 22:08:46 +08:00
|
|
|
|
return result, nil
|
2026-05-29 17:54:19 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-05-30 22:08:46 +08:00
|
|
|
|
// 把 result 转成 JSON 字符串,tidwall/gjson 需要字符串输入
|
|
|
|
|
|
resultBytes, _ := json.Marshal(result)
|
|
|
|
|
|
resultStr := string(resultBytes)
|
|
|
|
|
|
|
|
|
|
|
|
mapped := make(map[string]any)
|
2026-05-29 17:54:19 +08:00
|
|
|
|
|
|
|
|
|
|
for standardField, modelPath := range mapping {
|
|
|
|
|
|
path := gconv.String(modelPath)
|
|
|
|
|
|
if path == "" {
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
2026-05-30 22:08:46 +08:00
|
|
|
|
|
|
|
|
|
|
value := tgjson.Get(resultStr, path)
|
|
|
|
|
|
if !value.Exists() {
|
2026-05-29 17:54:19 +08:00
|
|
|
|
continue
|
|
|
|
|
|
}
|
2026-05-30 22:08:46 +08:00
|
|
|
|
// 如果是数组路径(含 #),取 Array;否则取单值
|
|
|
|
|
|
if strings.Contains(path, "#") {
|
|
|
|
|
|
var arr []any
|
|
|
|
|
|
for _, v := range value.Array() {
|
|
|
|
|
|
arr = append(arr, v.Value())
|
|
|
|
|
|
}
|
|
|
|
|
|
mapped[standardField] = arr
|
|
|
|
|
|
} else {
|
|
|
|
|
|
mapped[standardField] = value.Value()
|
|
|
|
|
|
}
|
2026-05-29 17:54:19 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-05-30 22:08:46 +08:00
|
|
|
|
return mapped, nil
|
2026-05-29 17:54:19 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-06-02 20:26:45 +08:00
|
|
|
|
// ParseHeadMsgHeaders 从 head_msg JSON 中提取请求头
|
|
|
|
|
|
// head_msg 格式示例:
|
2026-05-29 17:54:19 +08:00
|
|
|
|
//
|
2026-06-02 20:26:45 +08:00
|
|
|
|
// {
|
|
|
|
|
|
// "Authorization": "Bearer xxx",
|
|
|
|
|
|
// "Content-Type": "application/json",
|
|
|
|
|
|
// "X-Api-App-Id": "5147401364",
|
|
|
|
|
|
// "X-Api-Access-Key": "VCqRX7..."
|
|
|
|
|
|
// }
|
|
|
|
|
|
func ParseHeadMsgHeaders(headMsg map[string]any) map[string]string {
|
|
|
|
|
|
if len(headMsg) == 0 {
|
2026-05-29 17:54:19 +08:00
|
|
|
|
return nil
|
|
|
|
|
|
}
|
2026-06-02 20:26:45 +08:00
|
|
|
|
out := make(map[string]string, len(headMsg))
|
|
|
|
|
|
for k, v := range headMsg {
|
|
|
|
|
|
out[k] = gconv.String(v)
|
2026-05-29 17:54:19 +08:00
|
|
|
|
}
|
2026-06-02 20:26:45 +08:00
|
|
|
|
return out
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// GetModelBody 获取数据库中保存的模型信息
|
|
|
|
|
|
func GetModelBody(v map[string]any) map[string]any {
|
|
|
|
|
|
if v == nil {
|
2026-05-29 17:54:19 +08:00
|
|
|
|
return nil
|
|
|
|
|
|
}
|
2026-06-02 20:26:45 +08:00
|
|
|
|
if p, ok := v["body"]; ok {
|
|
|
|
|
|
return gconv.Map(p)
|
|
|
|
|
|
}
|
|
|
|
|
|
return v
|
2026-05-29 17:54:19 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-06-02 20:26:45 +08:00
|
|
|
|
// BodyToQuery 将 body 转为 url.Values
|
|
|
|
|
|
func BodyToQuery(payload map[string]any) (url.Values, error) {
|
2026-05-29 17:54:19 +08:00
|
|
|
|
q := url.Values{}
|
|
|
|
|
|
for k, v := range payload {
|
|
|
|
|
|
if v == nil {
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
q.Set(k, gconv.String(v))
|
|
|
|
|
|
}
|
|
|
|
|
|
return q, nil
|
|
|
|
|
|
}
|
2026-06-02 20:26:45 +08:00
|
|
|
|
|
|
|
|
|
|
// PullTaskResult 轮询查询任务结果直到完成
|
|
|
|
|
|
func PullTaskResult(ctx context.Context, taskID string, queryConfig map[string]any) (map[string]any, error) {
|
|
|
|
|
|
// 1. 解析配置
|
|
|
|
|
|
url := gconv.String(queryConfig["url"])
|
|
|
|
|
|
method := gconv.String(queryConfig["method"])
|
|
|
|
|
|
headers, _ := queryConfig["headers"].(map[string]any)
|
|
|
|
|
|
interval := gconv.Int(queryConfig["interval_seconds"])
|
|
|
|
|
|
if interval <= 0 {
|
|
|
|
|
|
interval = 2
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if method == "" {
|
|
|
|
|
|
method = "GET"
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 2. 构建参数
|
|
|
|
|
|
params := map[string]any{"id": taskID}
|
|
|
|
|
|
|
|
|
|
|
|
// 3. 替换 URL 中的 {id}
|
|
|
|
|
|
finalURL := replaceURLParams(url, params)
|
|
|
|
|
|
|
|
|
|
|
|
// 4. 构建请求体
|
|
|
|
|
|
bodyCfg, _ := queryConfig["body"].(map[string]any)
|
|
|
|
|
|
body := buildParams(bodyCfg, params)
|
|
|
|
|
|
|
|
|
|
|
|
// 5. 轮询
|
|
|
|
|
|
for {
|
|
|
|
|
|
select {
|
|
|
|
|
|
case <-ctx.Done():
|
|
|
|
|
|
return nil, ctx.Err()
|
|
|
|
|
|
default:
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
var reqBody io.Reader
|
|
|
|
|
|
if method == "POST" && body != nil {
|
|
|
|
|
|
bs, _ := json.Marshal(body)
|
|
|
|
|
|
reqBody = bytes.NewReader(bs)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
req, err := http.NewRequestWithContext(ctx, method, finalURL, reqBody)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, fmt.Errorf("创建请求失败: %w", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for k, v := range headers {
|
|
|
|
|
|
req.Header.Set(k, gconv.String(v))
|
|
|
|
|
|
}
|
|
|
|
|
|
if req.Header.Get("Content-Type") == "" && reqBody != nil {
|
|
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
client := &http.Client{Timeout: 30 * time.Second}
|
|
|
|
|
|
resp, err := client.Do(req)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
g.Log().Warningf(ctx, "[PullTaskResult] 请求失败 taskID=%s err=%v", taskID, err)
|
|
|
|
|
|
time.Sleep(time.Duration(interval) * time.Second)
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
|
|
|
|
all, _ := io.ReadAll(resp.Body)
|
|
|
|
|
|
resp.Body.Close()
|
|
|
|
|
|
g.Log().Warningf(ctx, "[PullTaskResult] 请求异常 taskID=%s status=%d body=%s", taskID, resp.StatusCode, string(all))
|
|
|
|
|
|
time.Sleep(time.Duration(interval) * time.Second)
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
var result map[string]any
|
|
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|
|
|
|
|
resp.Body.Close()
|
|
|
|
|
|
g.Log().Warningf(ctx, "[PullTaskResult] 解析失败 taskID=%s err=%v", taskID, err)
|
|
|
|
|
|
time.Sleep(time.Duration(interval) * time.Second)
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
resp.Body.Close()
|
|
|
|
|
|
|
|
|
|
|
|
status := gconv.String(result["status"])
|
|
|
|
|
|
g.Log().Infof(ctx, "[PullTaskResult] 轮询 taskID=%s status=%s", taskID, status)
|
|
|
|
|
|
|
|
|
|
|
|
switch status {
|
|
|
|
|
|
case "succeeded":
|
|
|
|
|
|
return result, nil
|
|
|
|
|
|
case "failed", "expired":
|
|
|
|
|
|
return result, fmt.Errorf("任务失败: status=%s", status)
|
|
|
|
|
|
case "queued", "running":
|
|
|
|
|
|
time.Sleep(time.Duration(interval) * time.Second)
|
|
|
|
|
|
continue
|
|
|
|
|
|
default:
|
|
|
|
|
|
// 兼容没有 status 字段的情况,直接返回
|
|
|
|
|
|
return result, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// buildParams merges bodyCfg and params into a new map.
//
// Entries from params overwrite entries from bodyCfg with the same key.
// Neither input is modified (the copy is shallow), and the result is never
// nil, even when both inputs are nil.
func buildParams(bodyCfg map[string]any, params map[string]any) map[string]any {
	merged := make(map[string]any, len(bodyCfg)+len(params))
	// Later sources win, so params is applied after bodyCfg.
	for _, src := range [...]map[string]any{bodyCfg, params} {
		for key, val := range src {
			merged[key] = val
		}
	}
	return merged
}
|
|
|
|
|
|
|
|
|
|
|
|
// replaceURLParams 替换 URL 中的 {key}
|
|
|
|
|
|
func replaceURLParams(url string, params map[string]any) string {
|
|
|
|
|
|
re := regexp.MustCompile(`\{([^}]+)\}`)
|
|
|
|
|
|
return re.ReplaceAllStringFunc(url, func(s string) string {
|
|
|
|
|
|
key := strings.Trim(s, "{}")
|
|
|
|
|
|
if val, ok := params[key]; ok {
|
|
|
|
|
|
return gconv.String(val)
|
|
|
|
|
|
}
|
|
|
|
|
|
return s
|
|
|
|
|
|
})
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// replaceBodyParams returns a new map holding every entry of bodyCfg, with
// entries from params applied on top (params wins on duplicate keys).
//
// Neither input is modified (the copy is shallow). The result is never nil.
//
// NOTE(review): this behaves identically to buildParams in this file and has
// no caller in the visible code; consider removing one of the two.
func replaceBodyParams(bodyCfg map[string]any, params map[string]any) map[string]any {
	// Presize like buildParams: the result never exceeds the combined input size.
	result := make(map[string]any, len(bodyCfg)+len(params))
	for k, v := range bodyCfg {
		result[k] = v
	}
	for k, v := range params {
		result[k] = v
	}
	return result
}
|
|
|
|
|
|
|
|
|
|
|
|
// InjectCallbackURL 将回调地址注入到请求体中
|
|
|
|
|
|
func InjectCallbackURL(ctx context.Context, payload map[string]any, callbackURL string) map[string]any {
|
|
|
|
|
|
if callbackURL == "" {
|
|
|
|
|
|
return payload
|
|
|
|
|
|
}
|
|
|
|
|
|
payload[callbackURL] = GetCallbackURL(ctx, "/task/modelCallback")
|
|
|
|
|
|
return payload
|
|
|
|
|
|
}
|