refactor(util): 重构映射工具函数并优化异步任务轮询逻辑
This commit is contained in:
@@ -52,21 +52,24 @@ func ValidatePromptResult(raw map[string]any, model *entity.AsynchModel) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReverseMap 映射 payload 到 mapping
|
// ParseHeadMsgHeaders 从 head_msg JSON 中提取请求头
|
||||||
func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
|
// head_msg 格式示例:
|
||||||
jsonObj := gjson.New("{}")
|
//
|
||||||
for path, defaultValue := range mapping {
|
// {
|
||||||
// 从 payload 取对应路径的值
|
// "Authorization": "Bearer xxx",
|
||||||
val := gjson.New(payload).Get(path)
|
// "Content-Type": "application/json",
|
||||||
if !val.IsNil() {
|
// "X-Api-App-Id": "5147401364",
|
||||||
// payload 有值,用它
|
// "X-Api-Access-Key": "VCqRX7..."
|
||||||
_ = jsonObj.Set(path, val.Val())
|
// }
|
||||||
} else if !g.IsEmpty(defaultValue) {
|
func ParseHeadMsgHeaders(headMsg map[string]any) map[string]string {
|
||||||
// payload 没值,用默认值
|
if len(headMsg) == 0 {
|
||||||
_ = jsonObj.Set(path, defaultValue)
|
return nil
|
||||||
}
|
}
|
||||||
|
out := make(map[string]string, len(headMsg))
|
||||||
|
for k, v := range headMsg {
|
||||||
|
out[k] = gconv.String(v)
|
||||||
}
|
}
|
||||||
return jsonObj.Map()
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
// MapResponsePayload 映射模型响应为标准格式
|
// MapResponsePayload 映射模型响应为标准格式
|
||||||
@@ -106,26 +109,6 @@ func MapResponsePayload(mapping map[string]any, result map[string]any) (map[stri
|
|||||||
return mapped, nil
|
return mapped, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseHeadMsgHeaders 从 head_msg JSON 中提取请求头
|
|
||||||
// head_msg 格式示例:
|
|
||||||
//
|
|
||||||
// {
|
|
||||||
// "Authorization": "Bearer xxx",
|
|
||||||
// "Content-Type": "application/json",
|
|
||||||
// "X-Api-App-Id": "5147401364",
|
|
||||||
// "X-Api-Access-Key": "VCqRX7..."
|
|
||||||
// }
|
|
||||||
func ParseHeadMsgHeaders(headMsg map[string]any) map[string]string {
|
|
||||||
if len(headMsg) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
out := make(map[string]string, len(headMsg))
|
|
||||||
for k, v := range headMsg {
|
|
||||||
out[k] = gconv.String(v)
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetModelBody 获取数据库中保存的模型信息
|
// GetModelBody 获取数据库中保存的模型信息
|
||||||
func GetModelBody(v map[string]any) map[string]any {
|
func GetModelBody(v map[string]any) map[string]any {
|
||||||
if v == nil {
|
if v == nil {
|
||||||
@@ -149,32 +132,44 @@ func BodyToQuery(payload map[string]any) (url.Values, error) {
|
|||||||
return q, nil
|
return q, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// PullTaskResult 轮询查询任务结果直到完成
|
// PullTaskResult 轮询查询异步任务结果直到完成
|
||||||
func PullTaskResult(ctx context.Context, taskID string, queryConfig map[string]any) (map[string]any, error) {
|
func PullTaskResult(ctx context.Context, body map[string]any, queryConfig map[string]any, headMsg map[string]any) (map[string]any, error) {
|
||||||
// 1. 解析配置
|
// 1) 解析配置
|
||||||
url := gconv.String(queryConfig["url"])
|
// 1.1 提取 taskID
|
||||||
|
taskIDPath := gconv.String(queryConfig["task_id"])
|
||||||
|
taskID := gconv.String(gjson.New(body).Get(taskIDPath).Val())
|
||||||
|
if taskID == "" {
|
||||||
|
return nil, fmt.Errorf("无法从路径 %s 提取 taskID", taskIDPath)
|
||||||
|
}
|
||||||
|
g.Log().Infof(ctx, "[PullTaskResult] taskID=%s", taskID)
|
||||||
|
|
||||||
|
// 1.2 请求地址,替换 {id}
|
||||||
|
queryUrl := gconv.String(queryConfig["url"])
|
||||||
|
queryUrl = replaceURLParams(queryUrl, map[string]any{"id": taskID})
|
||||||
|
|
||||||
|
// 1.3 请求方式
|
||||||
method := gconv.String(queryConfig["method"])
|
method := gconv.String(queryConfig["method"])
|
||||||
headers, _ := queryConfig["headers"].(map[string]any)
|
if method == "" {
|
||||||
|
method = "GET"
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1.4 状态判断配置
|
||||||
|
statusPath := gconv.String(queryConfig["status_path"])
|
||||||
|
statusValues, _ := queryConfig["status_values"].(map[string]any)
|
||||||
|
if statusPath == "" {
|
||||||
|
statusPath = "status"
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1.5 轮询间隔
|
||||||
interval := gconv.Int(queryConfig["interval_seconds"])
|
interval := gconv.Int(queryConfig["interval_seconds"])
|
||||||
if interval <= 0 {
|
if interval <= 0 {
|
||||||
interval = 2
|
interval = 2
|
||||||
}
|
}
|
||||||
|
|
||||||
if method == "" {
|
// 1.6 请求体
|
||||||
method = "GET"
|
reqBodyMap := map[string]any{"task_id": taskID}
|
||||||
}
|
|
||||||
|
|
||||||
// 2. 构建参数
|
// 2) 轮询请求
|
||||||
params := map[string]any{"id": taskID}
|
|
||||||
|
|
||||||
// 3. 替换 URL 中的 {id}
|
|
||||||
finalURL := replaceURLParams(url, params)
|
|
||||||
|
|
||||||
// 4. 构建请求体
|
|
||||||
bodyCfg, _ := queryConfig["body"].(map[string]any)
|
|
||||||
body := buildParams(bodyCfg, params)
|
|
||||||
|
|
||||||
// 5. 轮询
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
@@ -183,21 +178,19 @@ func PullTaskResult(ctx context.Context, taskID string, queryConfig map[string]a
|
|||||||
}
|
}
|
||||||
|
|
||||||
var reqBody io.Reader
|
var reqBody io.Reader
|
||||||
if method == "POST" && body != nil {
|
if method == "POST" {
|
||||||
bs, _ := json.Marshal(body)
|
bs, _ := json.Marshal(reqBodyMap)
|
||||||
reqBody = bytes.NewReader(bs)
|
reqBody = bytes.NewReader(bs)
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, method, finalURL, reqBody)
|
req, err := http.NewRequestWithContext(ctx, method, queryUrl, reqBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for k, v := range headers {
|
// 统一用 headMsg 注入请求头
|
||||||
req.Header.Set(k, gconv.String(v))
|
for hk, hv := range ParseHeadMsgHeaders(headMsg) {
|
||||||
}
|
req.Header.Set(hk, hv)
|
||||||
if req.Header.Get("Content-Type") == "" && reqBody != nil {
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
client := &http.Client{Timeout: 30 * time.Second}
|
client := &http.Client{Timeout: 30 * time.Second}
|
||||||
@@ -208,56 +201,54 @@ func PullTaskResult(ctx context.Context, taskID string, queryConfig map[string]a
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
raw, _ := io.ReadAll(resp.Body)
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
|
g.Log().Infof(ctx, "[PullTaskResult] taskID=%s statusCode=%d body=%s", taskID, resp.StatusCode, string(raw))
|
||||||
|
|
||||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
all, _ := io.ReadAll(resp.Body)
|
|
||||||
resp.Body.Close()
|
|
||||||
g.Log().Warningf(ctx, "[PullTaskResult] 请求异常 taskID=%s status=%d body=%s", taskID, resp.StatusCode, string(all))
|
|
||||||
time.Sleep(time.Duration(interval) * time.Second)
|
time.Sleep(time.Duration(interval) * time.Second)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
var result map[string]any
|
var result map[string]any
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
_ = json.Unmarshal(raw, &result)
|
||||||
resp.Body.Close()
|
|
||||||
g.Log().Warningf(ctx, "[PullTaskResult] 解析失败 taskID=%s err=%v", taskID, err)
|
|
||||||
time.Sleep(time.Duration(interval) * time.Second)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
resp.Body.Close()
|
|
||||||
|
|
||||||
status := gconv.String(result["status"])
|
statusVal := gjson.New(result).Get(statusPath).Val()
|
||||||
g.Log().Infof(ctx, "[PullTaskResult] 轮询 taskID=%s status=%s", taskID, status)
|
statusStr := gconv.String(statusVal)
|
||||||
|
g.Log().Infof(ctx, "[PullTaskResult] 状态 taskID=%s status=%v", taskID, statusVal)
|
||||||
|
|
||||||
switch status {
|
if matchStatus(statusStr, statusValues["succeeded"]) {
|
||||||
case "succeeded":
|
g.Log().Infof(ctx, "[PullTaskResult] 任务成功 taskID=%s", taskID)
|
||||||
return result, nil
|
|
||||||
case "failed", "expired":
|
|
||||||
return result, fmt.Errorf("任务失败: status=%s", status)
|
|
||||||
case "queued", "running":
|
|
||||||
time.Sleep(time.Duration(interval) * time.Second)
|
|
||||||
continue
|
|
||||||
default:
|
|
||||||
// 兼容没有 status 字段的情况,直接返回
|
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if matchStatus(statusStr, statusValues["failed"]) {
|
||||||
|
g.Log().Errorf(ctx, "[PullTaskResult] 任务失败 taskID=%s", taskID)
|
||||||
|
return result, fmt.Errorf("任务失败")
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(time.Duration(interval) * time.Second)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildParams 构建请求参数,用 params 覆盖 bodyCfg 中对应 key
|
func matchStatus(actual string, expected any) bool {
|
||||||
func buildParams(bodyCfg map[string]any, params map[string]any) map[string]any {
|
switch v := expected.(type) {
|
||||||
result := make(map[string]any, len(bodyCfg)+len(params))
|
case string:
|
||||||
for k, v := range bodyCfg {
|
return actual == v
|
||||||
result[k] = v
|
case []any:
|
||||||
|
for _, item := range v {
|
||||||
|
if actual == gconv.String(item) {
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
for k, v := range params {
|
|
||||||
result[k] = v
|
|
||||||
}
|
}
|
||||||
return result
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// replaceURLParams 替换 URL 中的 {key}
|
// replaceURLParams 替换 URL 中的 {key}
|
||||||
func replaceURLParams(url string, params map[string]any) string {
|
func replaceURLParams(url string, params map[string]any) string {
|
||||||
re := regexp.MustCompile(`\{([^}]+)\}`)
|
re := regexp.MustCompile(`\{([^}]+)}`)
|
||||||
return re.ReplaceAllStringFunc(url, func(s string) string {
|
return re.ReplaceAllStringFunc(url, func(s string) string {
|
||||||
key := strings.Trim(s, "{}")
|
key := strings.Trim(s, "{}")
|
||||||
if val, ok := params[key]; ok {
|
if val, ok := params[key]; ok {
|
||||||
@@ -267,18 +258,6 @@ func replaceURLParams(url string, params map[string]any) string {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// replaceBodyParams 用 params 覆盖 body 中对应 key
|
|
||||||
func replaceBodyParams(bodyCfg map[string]any, params map[string]any) map[string]any {
|
|
||||||
result := make(map[string]any)
|
|
||||||
for k, v := range bodyCfg {
|
|
||||||
result[k] = v
|
|
||||||
}
|
|
||||||
for k, v := range params {
|
|
||||||
result[k] = v
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// InjectCallbackURL 将回调地址注入到请求体中
|
// InjectCallbackURL 将回调地址注入到请求体中
|
||||||
func InjectCallbackURL(ctx context.Context, payload map[string]any, callbackURL string) map[string]any {
|
func InjectCallbackURL(ctx context.Context, payload map[string]any, callbackURL string) map[string]any {
|
||||||
if callbackURL == "" {
|
if callbackURL == "" {
|
||||||
|
|||||||
@@ -60,18 +60,6 @@ func (d *modelDao) Delete(ctx context.Context, req *entity.AsynchModel) (rows in
|
|||||||
|
|
||||||
// Get 按ID获取(带租户隔离,只查当前租户)
|
// Get 按ID获取(带租户隔离,只查当前租户)
|
||||||
func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
|
func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
|
||||||
//r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
|
||||||
// OmitEmpty().
|
|
||||||
// Where(entity.AsynchModelCol.Id, req.Id).
|
|
||||||
// Where(entity.AsynchModelCol.Creator, req.Creator).
|
|
||||||
// Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel).
|
|
||||||
// Where(entity.AsynchModelCol.ModelName, req.ModelName).
|
|
||||||
// Fields(fields).One()
|
|
||||||
//if err != nil {
|
|
||||||
// return
|
|
||||||
//}
|
|
||||||
//err = r.Struct(&m)
|
|
||||||
|
|
||||||
var whereCondition strings.Builder
|
var whereCondition strings.Builder
|
||||||
var queryParams []interface{}
|
var queryParams []interface{}
|
||||||
if !g.IsEmpty(req.Id) {
|
if !g.IsEmpty(req.Id) {
|
||||||
|
|||||||
@@ -98,6 +98,7 @@ type GetModelReq struct {
|
|||||||
g.Meta `path:"/getModel" method:"get" tags:"模型管理" summary:"获取模型配置" dc:"根据模型ID获取配置详情"`
|
g.Meta `path:"/getModel" method:"get" tags:"模型管理" summary:"获取模型配置" dc:"根据模型ID获取配置详情"`
|
||||||
ID int64 `p:"id" json:"id,string" dc:"配置ID"`
|
ID int64 `p:"id" json:"id,string" dc:"配置ID"`
|
||||||
Creator string `p:"creator" json:"creator" dc:"创建人"`
|
Creator string `p:"creator" json:"creator" dc:"创建人"`
|
||||||
|
IsChatModel *int `p:"isChatModel" json:"isChatModel" dc:"是否为聊天模型"`
|
||||||
ModelName string `p:"modelName" json:"modelName" dc:"模型名称(唯一标识)"`
|
ModelName string `p:"modelName" json:"modelName" dc:"模型名称(唯一标识)"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -91,12 +91,14 @@ func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetM
|
|||||||
if g.IsEmpty(req.ID) {
|
if g.IsEmpty(req.ID) {
|
||||||
req.Creator = user.UserName
|
req.Creator = user.UserName
|
||||||
}
|
}
|
||||||
modelReq := new(entity.AsynchModel)
|
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||||
err = gconv.Struct(req, modelReq)
|
SQLBaseDO: beans.SQLBaseDO{
|
||||||
if err != nil {
|
Id: req.ID,
|
||||||
return nil, err
|
Creator: user.UserName,
|
||||||
}
|
},
|
||||||
model, err := dao.Model.Get(ctx, modelReq)
|
ModelName: req.ModelName,
|
||||||
|
IsChatModel: req.IsChatModel,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -179,7 +179,7 @@ func (s *taskService) QueryPendingTasks(ctx context.Context, req *dto.QueryPendi
|
|||||||
if err != nil || model == nil || model.QueryConfig == nil {
|
if err != nil || model == nil || model.QueryConfig == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
result, err := util.PullTaskResult(ctx, t.TaskID, model.QueryConfig)
|
result, err := util.PullTaskResult(ctx, nil, model.QueryConfig, model.HeadMsg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
g.Log().Warningf(ctx, "[轮询] 查询失败 taskID=%s err=%v", t.TaskID, err)
|
g.Log().Warningf(ctx, "[轮询] 查询失败 taskID=%s err=%v", t.TaskID, err)
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -122,9 +122,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo
|
|||||||
w.failTask(ctx, task, err.Error())
|
w.failTask(ctx, task, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 拿到 task_id,启动轮询
|
body, err = util.PullTaskResult(ctx, body, model.QueryConfig, model.HeadMsg)
|
||||||
taskID := gjson.New(body).Get(model.ResponseBody).String()
|
|
||||||
body, err = util.PullTaskResult(ctx, taskID, model.QueryConfig)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.failTask(ctx, task, err.Error())
|
w.failTask(ctx, task, err.Error())
|
||||||
return
|
return
|
||||||
|
|||||||
Reference in New Issue
Block a user