2026-05-29 17:54:19 +08:00
|
|
|
|
package util
|
|
|
|
|
|
|
|
|
|
|
|
import (
|
2026-06-02 20:26:45 +08:00
|
|
|
|
"bytes"
|
|
|
|
|
|
"context"
|
2026-05-30 22:08:46 +08:00
|
|
|
|
"encoding/json"
|
2026-05-29 17:54:19 +08:00
|
|
|
|
"fmt"
|
2026-06-02 20:26:45 +08:00
|
|
|
|
"io"
|
2026-05-29 17:54:19 +08:00
|
|
|
|
"model-gateway/model/entity"
|
2026-06-02 20:26:45 +08:00
|
|
|
|
"net/http"
|
2026-05-29 17:54:19 +08:00
|
|
|
|
"net/url"
|
2026-06-02 20:26:45 +08:00
|
|
|
|
"regexp"
|
2026-05-29 17:54:19 +08:00
|
|
|
|
"strings"
|
2026-06-02 20:26:45 +08:00
|
|
|
|
"time"
|
2026-05-29 17:54:19 +08:00
|
|
|
|
|
2026-06-03 18:37:17 +08:00
|
|
|
|
"gitea.com/red-future/common/utils"
|
2026-05-29 17:54:19 +08:00
|
|
|
|
"github.com/gogf/gf/v2/encoding/gjson"
|
|
|
|
|
|
"github.com/gogf/gf/v2/frame/g"
|
|
|
|
|
|
"github.com/gogf/gf/v2/util/gconv"
|
2026-05-30 22:08:46 +08:00
|
|
|
|
tgjson "github.com/tidwall/gjson"
|
2026-05-29 17:54:19 +08:00
|
|
|
|
)
|
|
|
|
|
|
|
2026-06-08 18:01:53 +08:00
|
|
|
|
// ParseAndValidate 解析并校验结果
|
|
|
|
|
|
func ParseAndValidate(raw map[string]any, model *entity.AsynchModel) (map[string]any, error) {
|
|
|
|
|
|
// 1) 解析 content 字符串为 rounds 数组
|
|
|
|
|
|
contentVal, ok := raw[model.ResponseBody]
|
|
|
|
|
|
if !ok {
|
|
|
|
|
|
return raw, fmt.Errorf("字段 %s 不存在", model.ResponseBody)
|
|
|
|
|
|
}
|
|
|
|
|
|
contentStr, ok := contentVal.(string)
|
|
|
|
|
|
if !ok || strings.TrimSpace(contentStr) == "" {
|
|
|
|
|
|
return raw, fmt.Errorf("字段 %s 为空或不是字符串", model.ResponseBody)
|
|
|
|
|
|
}
|
|
|
|
|
|
var arr []any
|
|
|
|
|
|
if err := json.Unmarshal([]byte(contentStr), &arr); err != nil {
|
|
|
|
|
|
return raw, fmt.Errorf("JSON解析失败: %w", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
if len(arr) == 0 {
|
|
|
|
|
|
return raw, fmt.Errorf("解析后数组为空")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 2) 校验必填字段
|
|
|
|
|
|
if len(model.RequiredFields) > 0 {
|
|
|
|
|
|
for i, r := range arr {
|
|
|
|
|
|
round, ok := r.(map[string]any)
|
|
|
|
|
|
if !ok {
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
for _, field := range model.RequiredFields {
|
|
|
|
|
|
if gjson.New(round).Get(field).IsNil() {
|
|
|
|
|
|
return raw, fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, field)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return map[string]any{"total_rounds": len(arr), "rounds": arr}, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// ParseStructResult 解析结构结果
|
|
|
|
|
|
func ParseStructResult(raw map[string]any, responseBody string) map[string]any {
|
|
|
|
|
|
contentVal := raw[responseBody]
|
|
|
|
|
|
// 是字符串,尝试解析
|
|
|
|
|
|
contentStr := gconv.String(contentVal)
|
|
|
|
|
|
if contentStr == "" || contentStr == "0" {
|
|
|
|
|
|
return map[string]any{
|
|
|
|
|
|
"total_rounds": 1,
|
|
|
|
|
|
"rounds": []map[string]any{{responseBody: raw}},
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 尝试解析为数组
|
|
|
|
|
|
var arr []any
|
|
|
|
|
|
if err := json.Unmarshal([]byte(contentStr), &arr); err == nil && len(arr) > 0 {
|
|
|
|
|
|
return map[string]any{
|
|
|
|
|
|
"total_rounds": 1,
|
|
|
|
|
|
"rounds": []map[string]any{{responseBody: arr}},
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 尝试解析为单个对象
|
|
|
|
|
|
var parsed any
|
|
|
|
|
|
if err := json.Unmarshal([]byte(contentStr), &parsed); err == nil {
|
|
|
|
|
|
return map[string]any{
|
|
|
|
|
|
"total_rounds": 1,
|
|
|
|
|
|
"rounds": []map[string]any{{responseBody: parsed}},
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 兜底:原始字符串作为内容
|
|
|
|
|
|
return map[string]any{
|
|
|
|
|
|
"total_rounds": 1,
|
|
|
|
|
|
"rounds": []map[string]any{{responseBody: contentStr}},
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-05-29 17:54:19 +08:00
|
|
|
|
// ValidatePromptResult 校验模型返回结果的 JSON 结构完整性
|
2026-06-08 18:01:53 +08:00
|
|
|
|
// raw 必须包含 "rounds" 字段,格式为 []map[string]any
|
2026-05-29 17:54:19 +08:00
|
|
|
|
func ValidatePromptResult(raw map[string]any, model *entity.AsynchModel) error {
|
2026-06-08 18:01:53 +08:00
|
|
|
|
// 1) 获取 rounds
|
|
|
|
|
|
roundsRaw, ok := raw["rounds"]
|
|
|
|
|
|
if !ok {
|
|
|
|
|
|
return fmt.Errorf("缺少 rounds 字段")
|
2026-05-29 17:54:19 +08:00
|
|
|
|
}
|
2026-06-08 18:01:53 +08:00
|
|
|
|
rounds, ok := roundsRaw.([]any)
|
|
|
|
|
|
if !ok {
|
|
|
|
|
|
return fmt.Errorf("rounds 不是数组")
|
2026-05-29 17:54:19 +08:00
|
|
|
|
}
|
|
|
|
|
|
if len(rounds) == 0 {
|
2026-06-08 18:01:53 +08:00
|
|
|
|
return fmt.Errorf("rounds 数组为空")
|
2026-05-29 17:54:19 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-06-08 18:01:53 +08:00
|
|
|
|
// 2) 没有配置必填字段,跳过
|
|
|
|
|
|
if len(model.RequiredFields) == 0 {
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 3) 逐条校验
|
|
|
|
|
|
for i, r := range rounds {
|
|
|
|
|
|
round, ok := r.(map[string]any)
|
|
|
|
|
|
if !ok {
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
for _, field := range model.RequiredFields {
|
|
|
|
|
|
if gjson.New(round).Get(field).IsNil() {
|
|
|
|
|
|
return fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, field)
|
2026-05-29 17:54:19 +08:00
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
2026-06-08 18:01:53 +08:00
|
|
|
|
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// validateRequiredFields 校验单个 round 对象的必选字段
|
|
|
|
|
|
func validateRequiredFields(round map[string]any, requiredFields []string, prefix string) error {
|
|
|
|
|
|
for _, field := range requiredFields {
|
|
|
|
|
|
if gjson.New(round).Get(field).IsNil() {
|
|
|
|
|
|
return fmt.Errorf("%s 缺少必填字段: %s", prefix, field)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
2026-05-29 17:54:19 +08:00
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-06-03 13:30:39 +08:00
|
|
|
|
// ParseHeadMsgHeaders 从 head_msg JSON 中提取请求头
|
|
|
|
|
|
// head_msg 格式示例:
|
|
|
|
|
|
//
|
|
|
|
|
|
// {
|
|
|
|
|
|
// "Authorization": "Bearer xxx",
|
|
|
|
|
|
// "Content-Type": "application/json",
|
|
|
|
|
|
// "X-Api-App-Id": "5147401364",
|
|
|
|
|
|
// "X-Api-Access-Key": "VCqRX7..."
|
|
|
|
|
|
// }
|
|
|
|
|
|
func ParseHeadMsgHeaders(headMsg map[string]any) map[string]string {
|
|
|
|
|
|
if len(headMsg) == 0 {
|
|
|
|
|
|
return nil
|
2026-05-29 17:54:19 +08:00
|
|
|
|
}
|
2026-06-03 13:30:39 +08:00
|
|
|
|
out := make(map[string]string, len(headMsg))
|
|
|
|
|
|
for k, v := range headMsg {
|
|
|
|
|
|
out[k] = gconv.String(v)
|
|
|
|
|
|
}
|
|
|
|
|
|
return out
|
2026-05-29 17:54:19 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// MapResponsePayload 映射模型响应为标准格式
|
2026-05-30 22:08:46 +08:00
|
|
|
|
func MapResponsePayload(mapping map[string]any, result map[string]any) (map[string]any, error) {
|
2026-05-29 17:54:19 +08:00
|
|
|
|
if len(mapping) == 0 {
|
2026-05-30 22:08:46 +08:00
|
|
|
|
return result, nil
|
2026-05-29 17:54:19 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-05-30 22:08:46 +08:00
|
|
|
|
// 把 result 转成 JSON 字符串,tidwall/gjson 需要字符串输入
|
|
|
|
|
|
resultBytes, _ := json.Marshal(result)
|
|
|
|
|
|
resultStr := string(resultBytes)
|
|
|
|
|
|
|
|
|
|
|
|
mapped := make(map[string]any)
|
2026-05-29 17:54:19 +08:00
|
|
|
|
|
|
|
|
|
|
for standardField, modelPath := range mapping {
|
|
|
|
|
|
path := gconv.String(modelPath)
|
|
|
|
|
|
if path == "" {
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
2026-05-30 22:08:46 +08:00
|
|
|
|
|
|
|
|
|
|
value := tgjson.Get(resultStr, path)
|
|
|
|
|
|
if !value.Exists() {
|
2026-05-29 17:54:19 +08:00
|
|
|
|
continue
|
|
|
|
|
|
}
|
2026-05-30 22:08:46 +08:00
|
|
|
|
// 如果是数组路径(含 #),取 Array;否则取单值
|
|
|
|
|
|
if strings.Contains(path, "#") {
|
|
|
|
|
|
var arr []any
|
|
|
|
|
|
for _, v := range value.Array() {
|
|
|
|
|
|
arr = append(arr, v.Value())
|
|
|
|
|
|
}
|
|
|
|
|
|
mapped[standardField] = arr
|
|
|
|
|
|
} else {
|
|
|
|
|
|
mapped[standardField] = value.Value()
|
|
|
|
|
|
}
|
2026-05-29 17:54:19 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-05-30 22:08:46 +08:00
|
|
|
|
return mapped, nil
|
2026-05-29 17:54:19 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-06-02 20:26:45 +08:00
|
|
|
|
// GetModelBody 获取数据库中保存的模型信息
|
|
|
|
|
|
func GetModelBody(v map[string]any) map[string]any {
|
|
|
|
|
|
if v == nil {
|
2026-05-29 17:54:19 +08:00
|
|
|
|
return nil
|
|
|
|
|
|
}
|
2026-06-02 20:26:45 +08:00
|
|
|
|
if p, ok := v["body"]; ok {
|
|
|
|
|
|
return gconv.Map(p)
|
|
|
|
|
|
}
|
|
|
|
|
|
return v
|
2026-05-29 17:54:19 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-06-02 20:26:45 +08:00
|
|
|
|
// BodyToQuery 将 body 转为 url.Values
|
|
|
|
|
|
func BodyToQuery(payload map[string]any) (url.Values, error) {
|
2026-05-29 17:54:19 +08:00
|
|
|
|
q := url.Values{}
|
|
|
|
|
|
for k, v := range payload {
|
|
|
|
|
|
if v == nil {
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
q.Set(k, gconv.String(v))
|
|
|
|
|
|
}
|
|
|
|
|
|
return q, nil
|
|
|
|
|
|
}
|
2026-06-02 20:26:45 +08:00
|
|
|
|
|
2026-06-03 13:30:39 +08:00
|
|
|
|
// PullTaskResult 轮询查询异步任务结果直到完成
|
|
|
|
|
|
func PullTaskResult(ctx context.Context, body map[string]any, queryConfig map[string]any, headMsg map[string]any) (map[string]any, error) {
|
|
|
|
|
|
// 1) 解析配置
|
|
|
|
|
|
// 1.1 提取 taskID
|
|
|
|
|
|
taskIDPath := gconv.String(queryConfig["task_id"])
|
|
|
|
|
|
taskID := gconv.String(gjson.New(body).Get(taskIDPath).Val())
|
|
|
|
|
|
if taskID == "" {
|
|
|
|
|
|
return nil, fmt.Errorf("无法从路径 %s 提取 taskID", taskIDPath)
|
2026-06-02 20:26:45 +08:00
|
|
|
|
}
|
2026-06-03 13:30:39 +08:00
|
|
|
|
g.Log().Infof(ctx, "[PullTaskResult] taskID=%s", taskID)
|
2026-06-02 20:26:45 +08:00
|
|
|
|
|
2026-06-03 13:30:39 +08:00
|
|
|
|
// 1.2 请求地址,替换 {id}
|
|
|
|
|
|
queryUrl := gconv.String(queryConfig["url"])
|
|
|
|
|
|
queryUrl = replaceURLParams(queryUrl, map[string]any{"id": taskID})
|
|
|
|
|
|
|
|
|
|
|
|
// 1.3 请求方式
|
|
|
|
|
|
method := gconv.String(queryConfig["method"])
|
2026-06-02 20:26:45 +08:00
|
|
|
|
if method == "" {
|
|
|
|
|
|
method = "GET"
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-06-03 13:30:39 +08:00
|
|
|
|
// 1.4 状态判断配置
|
|
|
|
|
|
statusPath := gconv.String(queryConfig["status_path"])
|
|
|
|
|
|
statusValues, _ := queryConfig["status_values"].(map[string]any)
|
|
|
|
|
|
if statusPath == "" {
|
|
|
|
|
|
statusPath = "status"
|
|
|
|
|
|
}
|
2026-06-02 20:26:45 +08:00
|
|
|
|
|
2026-06-03 13:30:39 +08:00
|
|
|
|
// 1.5 轮询间隔
|
|
|
|
|
|
interval := gconv.Int(queryConfig["interval_seconds"])
|
|
|
|
|
|
if interval <= 0 {
|
|
|
|
|
|
interval = 2
|
|
|
|
|
|
}
|
2026-06-02 20:26:45 +08:00
|
|
|
|
|
2026-06-03 13:30:39 +08:00
|
|
|
|
// 1.6 请求体
|
|
|
|
|
|
reqBodyMap := map[string]any{"task_id": taskID}
|
2026-06-02 20:26:45 +08:00
|
|
|
|
|
2026-06-03 13:30:39 +08:00
|
|
|
|
// 2) 轮询请求
|
2026-06-02 20:26:45 +08:00
|
|
|
|
for {
|
|
|
|
|
|
select {
|
|
|
|
|
|
case <-ctx.Done():
|
|
|
|
|
|
return nil, ctx.Err()
|
|
|
|
|
|
default:
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
var reqBody io.Reader
|
2026-06-03 13:30:39 +08:00
|
|
|
|
if method == "POST" {
|
|
|
|
|
|
bs, _ := json.Marshal(reqBodyMap)
|
2026-06-02 20:26:45 +08:00
|
|
|
|
reqBody = bytes.NewReader(bs)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-06-03 13:30:39 +08:00
|
|
|
|
req, err := http.NewRequestWithContext(ctx, method, queryUrl, reqBody)
|
2026-06-02 20:26:45 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, fmt.Errorf("创建请求失败: %w", err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-06-03 13:30:39 +08:00
|
|
|
|
// 统一用 headMsg 注入请求头
|
|
|
|
|
|
for hk, hv := range ParseHeadMsgHeaders(headMsg) {
|
|
|
|
|
|
req.Header.Set(hk, hv)
|
2026-06-02 20:26:45 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
client := &http.Client{Timeout: 30 * time.Second}
|
|
|
|
|
|
resp, err := client.Do(req)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
g.Log().Warningf(ctx, "[PullTaskResult] 请求失败 taskID=%s err=%v", taskID, err)
|
|
|
|
|
|
time.Sleep(time.Duration(interval) * time.Second)
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-06-03 13:30:39 +08:00
|
|
|
|
raw, _ := io.ReadAll(resp.Body)
|
|
|
|
|
|
_ = resp.Body.Close()
|
|
|
|
|
|
|
|
|
|
|
|
g.Log().Infof(ctx, "[PullTaskResult] taskID=%s statusCode=%d body=%s", taskID, resp.StatusCode, string(raw))
|
|
|
|
|
|
|
2026-06-02 20:26:45 +08:00
|
|
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
|
|
|
|
time.Sleep(time.Duration(interval) * time.Second)
|
|
|
|
|
|
continue
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
var result map[string]any
|
2026-06-03 13:30:39 +08:00
|
|
|
|
_ = json.Unmarshal(raw, &result)
|
2026-06-02 20:26:45 +08:00
|
|
|
|
|
2026-06-03 13:30:39 +08:00
|
|
|
|
statusVal := gjson.New(result).Get(statusPath).Val()
|
|
|
|
|
|
statusStr := gconv.String(statusVal)
|
|
|
|
|
|
g.Log().Infof(ctx, "[PullTaskResult] 状态 taskID=%s status=%v", taskID, statusVal)
|
2026-06-02 20:26:45 +08:00
|
|
|
|
|
2026-06-03 13:30:39 +08:00
|
|
|
|
if matchStatus(statusStr, statusValues["succeeded"]) {
|
|
|
|
|
|
g.Log().Infof(ctx, "[PullTaskResult] 任务成功 taskID=%s", taskID)
|
2026-06-02 20:26:45 +08:00
|
|
|
|
return result, nil
|
|
|
|
|
|
}
|
2026-06-03 13:30:39 +08:00
|
|
|
|
|
|
|
|
|
|
if matchStatus(statusStr, statusValues["failed"]) {
|
|
|
|
|
|
g.Log().Errorf(ctx, "[PullTaskResult] 任务失败 taskID=%s", taskID)
|
|
|
|
|
|
return result, fmt.Errorf("任务失败")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
time.Sleep(time.Duration(interval) * time.Second)
|
2026-06-02 20:26:45 +08:00
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-06-03 13:30:39 +08:00
|
|
|
|
func matchStatus(actual string, expected any) bool {
|
2026-06-05 11:00:04 +08:00
|
|
|
|
expectedStr := gconv.String(expected)
|
|
|
|
|
|
if actual == expectedStr {
|
|
|
|
|
|
return true
|
|
|
|
|
|
}
|
2026-06-03 13:30:39 +08:00
|
|
|
|
switch v := expected.(type) {
|
|
|
|
|
|
case []any:
|
|
|
|
|
|
for _, item := range v {
|
|
|
|
|
|
if actual == gconv.String(item) {
|
|
|
|
|
|
return true
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
2026-06-02 20:26:45 +08:00
|
|
|
|
}
|
2026-06-03 13:30:39 +08:00
|
|
|
|
return false
|
2026-06-02 20:26:45 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// replaceURLParams 替换 URL 中的 {key}
|
|
|
|
|
|
func replaceURLParams(url string, params map[string]any) string {
|
2026-06-03 13:30:39 +08:00
|
|
|
|
re := regexp.MustCompile(`\{([^}]+)}`)
|
2026-06-02 20:26:45 +08:00
|
|
|
|
return re.ReplaceAllStringFunc(url, func(s string) string {
|
|
|
|
|
|
key := strings.Trim(s, "{}")
|
|
|
|
|
|
if val, ok := params[key]; ok {
|
|
|
|
|
|
return gconv.String(val)
|
|
|
|
|
|
}
|
|
|
|
|
|
return s
|
|
|
|
|
|
})
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// InjectCallbackURL 将回调地址注入到请求体中
|
|
|
|
|
|
func InjectCallbackURL(ctx context.Context, payload map[string]any, callbackURL string) map[string]any {
|
|
|
|
|
|
if callbackURL == "" {
|
|
|
|
|
|
return payload
|
|
|
|
|
|
}
|
2026-06-03 18:37:17 +08:00
|
|
|
|
payload[callbackURL] = utils.GetCallbackURL(ctx, "/task/modelCallback")
|
2026-06-02 20:26:45 +08:00
|
|
|
|
return payload
|
|
|
|
|
|
}
|