From bcfcc7ed47abcebac130c1b0ee05ec2b0adf5332 Mon Sep 17 00:00:00 2001 From: WangLiZhao <1838393649@qq.com> Date: Wed, 3 Jun 2026 13:30:39 +0800 Subject: [PATCH] =?UTF-8?q?refactor(util):=20=E9=87=8D=E6=9E=84=E6=98=A0?= =?UTF-8?q?=E5=B0=84=E5=B7=A5=E5=85=B7=E5=87=BD=E6=95=B0=E5=B9=B6=E4=BC=98?= =?UTF-8?q?=E5=8C=96=E5=BC=82=E6=AD=A5=E4=BB=BB=E5=8A=A1=E8=BD=AE=E8=AF=A2?= =?UTF-8?q?=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/util/mapping.go | 189 +++++++++++++++------------------ dao/model_dao.go | 12 --- model/dto/model_dto.go | 9 +- service/model/model_service.go | 14 +-- service/task/task_service.go | 2 +- service/task/worker.go | 4 +- 6 files changed, 99 insertions(+), 131 deletions(-) diff --git a/common/util/mapping.go b/common/util/mapping.go index f4acb86..d092e1a 100644 --- a/common/util/mapping.go +++ b/common/util/mapping.go @@ -52,21 +52,24 @@ func ValidatePromptResult(raw map[string]any, model *entity.AsynchModel) error { return nil } -// ReverseMap 映射 payload 到 mapping -func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any { - jsonObj := gjson.New("{}") - for path, defaultValue := range mapping { - // 从 payload 取对应路径的值 - val := gjson.New(payload).Get(path) - if !val.IsNil() { - // payload 有值,用它 - _ = jsonObj.Set(path, val.Val()) - } else if !g.IsEmpty(defaultValue) { - // payload 没值,用默认值 - _ = jsonObj.Set(path, defaultValue) - } +// ParseHeadMsgHeaders 从 head_msg JSON 中提取请求头 +// head_msg 格式示例: +// +// { +// "Authorization": "Bearer xxx", +// "Content-Type": "application/json", +// "X-Api-App-Id": "5147401364", +// "X-Api-Access-Key": "VCqRX7..." +// } +func ParseHeadMsgHeaders(headMsg map[string]any) map[string]string { + if len(headMsg) == 0 { + return nil } - return jsonObj.Map() + out := make(map[string]string, len(headMsg)) + for k, v := range headMsg { + out[k] = gconv.String(v) + } + return out } // MapResponsePayload 映射模型响应为标准格式 @@ -106,26 +109,6 @@ func MapResponsePayload(mapping map[string]any, result map[string]any) (map[stri return mapped, nil } -// ParseHeadMsgHeaders 从 head_msg JSON 中提取请求头 -// head_msg 格式示例: -// -// { -// "Authorization": "Bearer xxx", -// "Content-Type": "application/json", -// "X-Api-App-Id": "5147401364", -// "X-Api-Access-Key": "VCqRX7..." -// } -func ParseHeadMsgHeaders(headMsg map[string]any) map[string]string { - if len(headMsg) == 0 { - return nil - } - out := make(map[string]string, len(headMsg)) - for k, v := range headMsg { - out[k] = gconv.String(v) - } - return out -} - // GetModelBody 获取数据库中保存的模型信息 func GetModelBody(v map[string]any) map[string]any { if v == nil { @@ -149,32 +132,44 @@ func BodyToQuery(payload map[string]any) (url.Values, error) { return q, nil } -// PullTaskResult 轮询查询任务结果直到完成 -func PullTaskResult(ctx context.Context, taskID string, queryConfig map[string]any) (map[string]any, error) { - // 1. 解析配置 - url := gconv.String(queryConfig["url"]) +// PullTaskResult 轮询查询异步任务结果直到完成 +func PullTaskResult(ctx context.Context, body map[string]any, queryConfig map[string]any, headMsg map[string]any) (map[string]any, error) { + // 1) 解析配置 + // 1.1 提取 taskID + taskIDPath := gconv.String(queryConfig["task_id"]) + taskID := gconv.String(gjson.New(body).Get(taskIDPath).Val()) + if taskID == "" { + return nil, fmt.Errorf("无法从路径 %s 提取 taskID", taskIDPath) + } + g.Log().Infof(ctx, "[PullTaskResult] taskID=%s", taskID) + + // 1.2 请求地址,替换 {id} + queryUrl := gconv.String(queryConfig["url"]) + queryUrl = replaceURLParams(queryUrl, map[string]any{"id": taskID}) + + // 1.3 请求方式 method := gconv.String(queryConfig["method"]) - headers, _ := queryConfig["headers"].(map[string]any) + if method == "" { + method = "GET" + } + + // 1.4 状态判断配置 + statusPath := gconv.String(queryConfig["status_path"]) + statusValues, _ := queryConfig["status_values"].(map[string]any) + if statusPath == "" { + statusPath = "status" + } + + // 1.5 轮询间隔 interval := gconv.Int(queryConfig["interval_seconds"]) if interval <= 0 { interval = 2 } - if method == "" { - method = "GET" - } + // 1.6 请求体 + reqBodyMap := map[string]any{"task_id": taskID} - // 2. 构建参数 - params := map[string]any{"id": taskID} - - // 3. 替换 URL 中的 {id} - finalURL := replaceURLParams(url, params) - - // 4. 构建请求体 - bodyCfg, _ := queryConfig["body"].(map[string]any) - body := buildParams(bodyCfg, params) - - // 5. 轮询 + // 2) 轮询请求 for { select { case <-ctx.Done(): @@ -183,21 +178,19 @@ func PullTaskResult(ctx context.Context, taskID string, queryConfig map[string]a } var reqBody io.Reader - if method == "POST" && body != nil { - bs, _ := json.Marshal(body) + if method == "POST" { + bs, _ := json.Marshal(reqBodyMap) reqBody = bytes.NewReader(bs) } - req, err := http.NewRequestWithContext(ctx, method, finalURL, reqBody) + req, err := http.NewRequestWithContext(ctx, method, queryUrl, reqBody) if err != nil { return nil, fmt.Errorf("创建请求失败: %w", err) } - for k, v := range headers { - req.Header.Set(k, gconv.String(v)) - } - if req.Header.Get("Content-Type") == "" && reqBody != nil { - req.Header.Set("Content-Type", "application/json") + // 统一用 headMsg 注入请求头 + for hk, hv := range ParseHeadMsgHeaders(headMsg) { + req.Header.Set(hk, hv) } client := &http.Client{Timeout: 30 * time.Second} @@ -208,56 +201,54 @@ func PullTaskResult(ctx context.Context, taskID string, queryConfig map[string]a continue } + raw, _ := io.ReadAll(resp.Body) + _ = resp.Body.Close() + + g.Log().Infof(ctx, "[PullTaskResult] taskID=%s statusCode=%d body=%s", taskID, resp.StatusCode, string(raw)) + if resp.StatusCode < 200 || resp.StatusCode >= 300 { - all, _ := io.ReadAll(resp.Body) - resp.Body.Close() - g.Log().Warningf(ctx, "[PullTaskResult] 请求异常 taskID=%s status=%d body=%s", taskID, resp.StatusCode, string(all)) time.Sleep(time.Duration(interval) * time.Second) continue } var result map[string]any - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - resp.Body.Close() - g.Log().Warningf(ctx, "[PullTaskResult] 解析失败 taskID=%s err=%v", taskID, err) - time.Sleep(time.Duration(interval) * time.Second) - continue - } - resp.Body.Close() + _ = json.Unmarshal(raw, &result) - status := gconv.String(result["status"]) - g.Log().Infof(ctx, "[PullTaskResult] 轮询 taskID=%s status=%s", taskID, status) + statusVal := gjson.New(result).Get(statusPath).Val() + statusStr := gconv.String(statusVal) + g.Log().Infof(ctx, "[PullTaskResult] 状态 taskID=%s status=%v", taskID, statusVal) - switch status { - case "succeeded": - return result, nil - case "failed", "expired": - return result, fmt.Errorf("任务失败: status=%s", status) - case "queued", "running": - time.Sleep(time.Duration(interval) * time.Second) - continue - default: - // 兼容没有 status 字段的情况,直接返回 + if matchStatus(statusStr, statusValues["succeeded"]) { + g.Log().Infof(ctx, "[PullTaskResult] 任务成功 taskID=%s", taskID) return result, nil } + + if matchStatus(statusStr, statusValues["failed"]) { + g.Log().Errorf(ctx, "[PullTaskResult] 任务失败 taskID=%s", taskID) + return result, fmt.Errorf("任务失败") + } + + time.Sleep(time.Duration(interval) * time.Second) } } -// buildParams 构建请求参数,用 params 覆盖 bodyCfg 中对应 key -func buildParams(bodyCfg map[string]any, params map[string]any) map[string]any { - result := make(map[string]any, len(bodyCfg)+len(params)) - for k, v := range bodyCfg { - result[k] = v +func matchStatus(actual string, expected any) bool { + switch v := expected.(type) { + case string: + return actual == v + case []any: + for _, item := range v { + if actual == gconv.String(item) { + return true + } + } } - for k, v := range params { - result[k] = v - } - return result + return false } // replaceURLParams 替换 URL 中的 {key} func replaceURLParams(url string, params map[string]any) string { - re := regexp.MustCompile(`\{([^}]+)\}`) + re := regexp.MustCompile(`\{([^}]+)}`) return re.ReplaceAllStringFunc(url, func(s string) string { key := strings.Trim(s, "{}") if val, ok := params[key]; ok { @@ -267,18 +258,6 @@ func replaceURLParams(url string, params map[string]any) string { }) } -// replaceBodyParams 用 params 覆盖 body 中对应 key -func replaceBodyParams(bodyCfg map[string]any, params map[string]any) map[string]any { - result := make(map[string]any) - for k, v := range bodyCfg { - result[k] = v - } - for k, v := range params { - result[k] = v - } - return result -} - // InjectCallbackURL 将回调地址注入到请求体中 func InjectCallbackURL(ctx context.Context, payload map[string]any, callbackURL string) map[string]any { if callbackURL == "" { diff --git a/dao/model_dao.go b/dao/model_dao.go index 50f379c..88a256d 100644 --- a/dao/model_dao.go +++ b/dao/model_dao.go @@ -60,18 +60,6 @@ func (d *modelDao) Delete(ctx context.Context, req *entity.AsynchModel) (rows in // Get 按ID获取(带租户隔离,只查当前租户) func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) { - //r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel). - // OmitEmpty(). - // Where(entity.AsynchModelCol.Id, req.Id). - // Where(entity.AsynchModelCol.Creator, req.Creator). - // Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel). - // Where(entity.AsynchModelCol.ModelName, req.ModelName). - // Fields(fields).One() - //if err != nil { - // return - //} - //err = r.Struct(&m) - var whereCondition strings.Builder var queryParams []interface{} if !g.IsEmpty(req.Id) { diff --git a/model/dto/model_dto.go b/model/dto/model_dto.go index 18151f8..fd972d1 100644 --- a/model/dto/model_dto.go +++ b/model/dto/model_dto.go @@ -95,10 +95,11 @@ type DeleteModelRes struct { // GetModelReq 获取模型配置详情 type GetModelReq struct { - g.Meta `path:"/getModel" method:"get" tags:"模型管理" summary:"获取模型配置" dc:"根据模型ID获取配置详情"` - ID int64 `p:"id" json:"id,string" dc:"配置ID"` - Creator string `p:"creator" json:"creator" dc:"创建人"` - ModelName string `p:"modelName" json:"modelName" dc:"模型名称(唯一标识)"` + g.Meta `path:"/getModel" method:"get" tags:"模型管理" summary:"获取模型配置" dc:"根据模型ID获取配置详情"` + ID int64 `p:"id" json:"id,string" dc:"配置ID"` + Creator string `p:"creator" json:"creator" dc:"创建人"` + IsChatModel *int `p:"isChatModel" json:"isChatModel" dc:"是否为聊天模型"` + ModelName string `p:"modelName" json:"modelName" dc:"模型名称(唯一标识)"` } type GetModelRes struct { diff --git a/service/model/model_service.go b/service/model/model_service.go index 0942948..7bb985c 100644 --- a/service/model/model_service.go +++ b/service/model/model_service.go @@ -91,12 +91,14 @@ func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetM if g.IsEmpty(req.ID) { req.Creator = user.UserName } - modelReq := new(entity.AsynchModel) - err = gconv.Struct(req, modelReq) - if err != nil { - return nil, err - } - model, err := dao.Model.Get(ctx, modelReq) + model, err := dao.Model.Get(ctx, &entity.AsynchModel{ + SQLBaseDO: beans.SQLBaseDO{ + Id: req.ID, + Creator: user.UserName, + }, + ModelName: req.ModelName, + IsChatModel: req.IsChatModel, + }) if err != nil { return nil, err } diff --git a/service/task/task_service.go b/service/task/task_service.go index 5e05984..f87c469 100644 --- a/service/task/task_service.go +++ b/service/task/task_service.go @@ -179,7 +179,7 @@ func (s *taskService) QueryPendingTasks(ctx context.Context, req *dto.QueryPendi if err != nil || model == nil || model.QueryConfig == nil { continue } - result, err := util.PullTaskResult(ctx, t.TaskID, model.QueryConfig) + result, err := util.PullTaskResult(ctx, nil, model.QueryConfig, model.HeadMsg) if err != nil { g.Log().Warningf(ctx, "[轮询] 查询失败 taskID=%s err=%v", t.TaskID, err) continue diff --git a/service/task/worker.go b/service/task/worker.go index 0f2f8bd..7d49d29 100644 --- a/service/task/worker.go +++ b/service/task/worker.go @@ -122,9 +122,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo w.failTask(ctx, task, err.Error()) return } - // 拿到 task_id,启动轮询 - taskID := gjson.New(body).Get(model.ResponseBody).String() - body, err = util.PullTaskResult(ctx, taskID, model.QueryConfig) + body, err = util.PullTaskResult(ctx, body, model.QueryConfig, model.HeadMsg) if err != nil { w.failTask(ctx, task, err.Error()) return