refactor(model): 优化模型网关的数据解析和任务处理逻辑

This commit is contained in:
2026-06-17 14:34:48 +08:00
parent b3b111995e
commit fddaf36f48
7 changed files with 231 additions and 166 deletions

View File

@@ -19,17 +19,20 @@ import (
tgjson "github.com/tidwall/gjson" tgjson "github.com/tidwall/gjson"
) )
// ParseAndValidate 解析并校验结果 // ParseAndValidate 解析模型响应,并返回标准格式
func ParseAndValidate(raw map[string]any, model *entity.ModelGatewayModel) (map[string]any, error) { func ParseAndValidate(raw map[string]any, model *entity.ModelGatewayModel) (map[string]any, error) {
// 1) 解析 content 字符串为 rounds 数组 contentStr := gconv.String(raw[entity.ResponseBody])
contentVal, ok := raw[model.ResponseBody] if strings.TrimSpace(contentStr) == "" {
if !ok { return raw, fmt.Errorf("字段 %s 为空", entity.ResponseBody)
return raw, fmt.Errorf("字段 %s 不存在", model.ResponseBody)
} }
contentStr, ok := contentVal.(string)
if !ok || strings.TrimSpace(contentStr) == "" { contentStr = strings.Map(func(r rune) rune {
return raw, fmt.Errorf("字段 %s 为空或不是字符串", model.ResponseBody) if r < 32 && r != ' ' {
return -1
} }
return r
}, contentStr)
var arr []any var arr []any
if err := json.Unmarshal([]byte(contentStr), &arr); err != nil { if err := json.Unmarshal([]byte(contentStr), &arr); err != nil {
return raw, fmt.Errorf("JSON解析失败: %w", err) return raw, fmt.Errorf("JSON解析失败: %w", err)
@@ -38,20 +41,14 @@ func ParseAndValidate(raw map[string]any, model *entity.ModelGatewayModel) (map[
return raw, fmt.Errorf("解析后数组为空") return raw, fmt.Errorf("解析后数组为空")
} }
// 2) 校验必填字段
if len(model.RequiredFields) > 0 {
for i, r := range arr {
round, ok := r.(map[string]any)
if !ok {
continue
}
for _, field := range model.RequiredFields { for _, field := range model.RequiredFields {
if gjson.New(round).Get(field).IsNil() { for i, r := range arr {
round, _ := r.(map[string]any)
if round != nil && gjson.New(round).Get(field).IsNil() {
return raw, fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, field) return raw, fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, field)
} }
} }
} }
}
return map[string]any{"total_rounds": len(arr), "rounds": arr}, nil return map[string]any{"total_rounds": len(arr), "rounds": arr}, nil
} }

View File

@@ -56,6 +56,7 @@ func (d *modelGatewayModelsDao) Get(ctx context.Context, req *entity.ModelGatewa
Where(entity.ModelGatewayModelCol.Id, req.Id). Where(entity.ModelGatewayModelCol.Id, req.Id).
Where(entity.ModelGatewayModelCol.Creator, req.Creator). Where(entity.ModelGatewayModelCol.Creator, req.Creator).
Where(entity.ModelGatewayModelCol.ModelName, req.ModelName). Where(entity.ModelGatewayModelCol.ModelName, req.ModelName).
Where(entity.ModelGatewayModelCol.IsChatModel, req.IsChatModel).
Fields(fields).One() Fields(fields).One()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -122,7 +123,7 @@ func (d *modelGatewayModelsDao) GetByAcrossTenant(ctx context.Context, req *enti
func (d *modelGatewayModelsDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.ModelGatewayModel, total int, err error) { func (d *modelGatewayModelsDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.ModelGatewayModel, total int, err error) {
sql := ` sql := `
SELECT DISTINCT ON (model_name) * SELECT DISTINCT ON (model_name) *
FROM asynch_models FROM ` + public.TableNameModel + `
WHERE deleted_at IS NULL WHERE deleted_at IS NULL
AND (? = '' OR model_name LIKE ?) AND (? = '' OR model_name LIKE ?)
` `

View File

@@ -7,7 +7,6 @@ import (
"model-gateway/model/entity" "model-gateway/model/entity"
"gitea.redpowerfuture.com/red-future/common/db/gfdb" "gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/util/gconv" "github.com/gogf/gf/v2/util/gconv"
) )
@@ -128,32 +127,32 @@ func (d *modelGatewayTaskDao) GetPendingAsyncTasks(ctx context.Context, limit in
// ClaimByID 按主键抢占,返回抢占后的任务 // ClaimByID 按主键抢占,返回抢占后的任务
func (d *modelGatewayTaskDao) ClaimByID(ctx context.Context, id int64) (*entity.ModelGatewayTask, error) { func (d *modelGatewayTaskDao) ClaimByID(ctx context.Context, id int64) (*entity.ModelGatewayTask, error) {
// 1) 先查任务
var task entity.ModelGatewayTask var task entity.ModelGatewayTask
err := gfdb.DB(ctx, public.DbNameModelGateway).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error { r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
r, err := tx.Model(public.TableNameTask).
Where(entity.ModelGatewayTaskCol.Id, id). Where(entity.ModelGatewayTaskCol.Id, id).
Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending). Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending).
Limit(1).
LockUpdate().
One() One()
if err != nil {
return err
}
if r.IsEmpty() {
return fmt.Errorf("任务已被抢占或不存在: id=%d", id)
}
if err := r.Struct(&task); err != nil {
return err
}
_, err = tx.Model(public.TableNameTask).
Data(&entity.ModelGatewayTask{State: public.TaskStatusRunning}).
Where(entity.ModelGatewayTaskCol.Id, id).
OmitEmpty().
Update()
return err
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
if r.IsEmpty() {
return nil, fmt.Errorf("任务已被抢占或不存在: id=%d", id)
}
if err = r.Struct(&task); err != nil {
return nil, err
}
// 2) 改为执行中
_, err = gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
Data(&entity.ModelGatewayTask{State: public.TaskStatusRunning}).
Where(entity.ModelGatewayTaskCol.Id, id).
Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending). // 防并发
OmitEmpty().
Update()
if err != nil {
return nil, err
}
return &task, nil return &task, nil
} }

View File

@@ -12,8 +12,6 @@ type modelGatewayModelCol struct {
FormJSON string FormJSON string
RequestMapping string RequestMapping string
ResponseMapping string ResponseMapping string
ResponseBody string
ResponseTokenField string
RequiredFields string RequiredFields string
IsPrivate string IsPrivate string
IsChatModel string IsChatModel string
@@ -45,8 +43,6 @@ var ModelGatewayModelCol = modelGatewayModelCol{
FormJSON: "form_json", FormJSON: "form_json",
RequestMapping: "request_mapping", RequestMapping: "request_mapping",
ResponseMapping: "response_mapping", ResponseMapping: "response_mapping",
ResponseBody: "response_body",
ResponseTokenField: "response_token_field",
RequiredFields: "required_fields", RequiredFields: "required_fields",
IsPrivate: "is_private", IsPrivate: "is_private",
IsChatModel: "is_chat_model", IsChatModel: "is_chat_model",
@@ -78,8 +74,6 @@ type ModelGatewayModel struct {
Form []map[string]any `orm:"form_json" json:"form"` Form []map[string]any `orm:"form_json" json:"form"`
RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"` RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"`
ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"` ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"`
ResponseBody string `orm:"response_body" json:"responseBody"`
ResponseTokenField string `orm:"response_token_field" json:"tokenField"`
RequiredFields []string `orm:"required_fields" json:"requiredFields"` RequiredFields []string `orm:"required_fields" json:"requiredFields"`
IsPrivate *int `orm:"is_private" json:"isPrivate"` IsPrivate *int `orm:"is_private" json:"isPrivate"`
IsChatModel *int `orm:"is_chat_model" json:"isChatModel"` IsChatModel *int `orm:"is_chat_model" json:"isChatModel"`
@@ -100,3 +94,8 @@ type ModelGatewayModel struct {
LastFrame string `orm:"last_frame" json:"lastFrame"` LastFrame string `orm:"last_frame" json:"lastFrame"`
MaxTokens int `orm:"max_tokens" json:"maxTokens"` MaxTokens int `orm:"max_tokens" json:"maxTokens"`
} }
// Fixed keys used to look up values in a model response map.
const (
ResponseBody = "content" // response body (required)
TotalTokens = "total_tokens" // total token count
)

View File

@@ -99,7 +99,7 @@ func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetM
ModelName: req.ModelName, ModelName: req.ModelName,
IsChatModel: req.IsChatModel, IsChatModel: req.IsChatModel,
}) })
if err != nil { if err != nil || model == nil {
return nil, err return nil, err
} }
return &dto.GetModelRes{ return &dto.GetModelRes{

View File

@@ -107,13 +107,27 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
}, },
}) })
// 5) 获取任务信息 // 5) 抢占任务:改为执行中
task, err := dao.ModelGatewayTask.ClaimByID(ctx, id) rows, err := dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
SQLBaseDO: beans.SQLBaseDO{Id: id},
State: public.TaskStatusRunning,
})
if err != nil {
return nil, err
}
if rows == 0 {
return nil, fmt.Errorf("任务不存在: id=%d", id)
}
// 6) 查询任务信息
task, err := dao.ModelGatewayTask.Get(ctx, &entity.ModelGatewayTask{
SQLBaseDO: beans.SQLBaseDO{Id: id},
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 5) 创建成功后立即异步尝试执行当前任务 // 7) 创建成功后立即异步尝试执行当前任务
go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req) go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req)
return &dto.CreateTaskRes{TaskID: taskID}, nil return &dto.CreateTaskRes{TaskID: taskID}, nil

View File

@@ -67,12 +67,19 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa
// ============================================ // ============================================
// 2) 调用模型 // 2) 调用模型
// ============================================ // ============================================
for attempt := 0; ; attempt++ {
if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] 调用模型 第%d次 taskId=%s", attempt, task.TaskID)
time.Sleep(time.Duration(attempt) * time.Second)
}
switch { switch {
case model.CallMode != nil && *model.CallMode == public.CallModeStream: case model.CallMode != nil && *model.CallMode == public.CallModeStream:
rawBytes, streamErr := w.callModelStream(ctx, task, model, body) rawBytes, streamErr := w.callModelStream(ctx, task, model, body)
if streamErr != nil { if streamErr != nil {
w.failTask(ctx, task, startTime, streamErr.Error()) err = streamErr
return g.Log().Warningf(ctx, "[执行任务][调用失败] taskId=%s attempt=%d err=%v", task.TaskID, attempt, err)
continue
} }
result, err = util.ParseStreamResponse(rawBytes, model.StreamConfig) result, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
case model.CallMode != nil && *model.CallMode == public.CallModeAsync: case model.CallMode != nil && *model.CallMode == public.CallModeAsync:
@@ -83,10 +90,17 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa
default: default:
result, err = w.callModel(ctx, task, model, body) result, err = w.callModel(ctx, task, model, body)
} }
if err != nil {
if err == nil {
break
}
if !strings.Contains(err.Error(), "Timeout") {
w.failTask(ctx, task, startTime, err.Error()) w.failTask(ctx, task, startTime, err.Error())
return return
} }
g.Log().Warningf(ctx, "[执行任务][调用失败] taskId=%s attempt=%d err=%v", task.TaskID, attempt, err)
}
// ============================================ // ============================================
// 3) 缓存临时文件 // 3) 缓存临时文件
@@ -205,7 +219,7 @@ func (w *asyncWorker) callModelAsync(ctx context.Context, task *entity.ModelGate
return nil, err return nil, err
} }
// 2. 拿到 task_id // 2. 拿到 task_id
taskID := gjson.New(body).Get(model.ResponseBody).String() taskID := gjson.New(body).Get(entity.ResponseBody).String()
// 3. 创建等待通道 // 3. 创建等待通道
ch := make(chan asyncResult, 1) ch := make(chan asyncResult, 1)
@@ -294,6 +308,8 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.ModelGatewayTa
// parseAndRetry 解析模型返回结果,并重试 // parseAndRetry 解析模型返回结果,并重试
func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) { func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) {
var lastErr error
for attempt := 0; attempt <= maxRetry; attempt++ { for attempt := 0; attempt <= maxRetry; attempt++ {
if attempt > 0 { if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID) g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
@@ -302,6 +318,7 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
// 1) 响应映射 // 1) 响应映射
mapped, err := util.MapResponsePayload(model.ResponseMapping, body) mapped, err := util.MapResponsePayload(model.ResponseMapping, body)
if err != nil { if err != nil {
lastErr = err
g.Log().Warningf(ctx, "[执行任务][映射失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err) g.Log().Warningf(ctx, "[执行任务][映射失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
if attempt == maxRetry { if attempt == maxRetry {
return nil, fmt.Errorf("响应映射重试耗尽: %w", err) return nil, fmt.Errorf("响应映射重试耗尽: %w", err)
@@ -309,10 +326,10 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
continue continue
} }
// 2) 存 token 到数据库,防止后续失败丢失 // 2) 存 token
if _, ok := mapped[model.ResponseTokenField]; ok { if _, ok := mapped[entity.TotalTokens]; ok {
task.ExpendTokens = gconv.Int64(mapped[model.ResponseTokenField]) task.ExpendTokens = gconv.Int64(mapped[entity.TotalTokens])
_, err = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{ _, _ = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
SQLBaseDO: beans.SQLBaseDO{Id: task.Id}, SQLBaseDO: beans.SQLBaseDO{Id: task.Id},
ExpendTokens: task.ExpendTokens, ExpendTokens: task.ExpendTokens,
}) })
@@ -326,9 +343,9 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
if err == nil { if err == nil {
return parsed, nil return parsed, nil
} }
lastErr = err
case public.BuildTypeStruct: case public.BuildTypeStruct:
parsed = util.ParseStructResult(mapped, model.ResponseBody) return util.ParseStructResult(mapped, entity.ResponseBody), nil
return parsed, nil
default: default:
return mapped, nil return mapped, nil
} }
@@ -336,22 +353,22 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err) g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
if attempt == maxRetry { if attempt == maxRetry {
return nil, fmt.Errorf("JSON解析重试耗尽: %w", err) return nil, fmt.Errorf("JSON解析重试耗尽: %w", lastErr)
} }
// 4) 重新调模型(直接调,不走缓存) // 4) 拼接错误信息到请求体,重调模型
task.RetryCount++ task.RetryCount++
_, _ = dao.ModelGatewayTask.Update(ctx, task) _, _ = dao.ModelGatewayTask.Update(ctx, task)
rawData, callErr := InvokeModel(ctx, model, task.RequestPayload.Body)
body = injectErrorMessage(task.RequestPayload.Body, lastErr)
rawData, callErr := InvokeModel(ctx, model, body)
if callErr != nil { if callErr != nil {
g.Log().Warningf(ctx, "[执行任务][重调模型失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, callErr) g.Log().Warningf(ctx, "[执行任务][重调模型失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, callErr)
continue continue
} }
// 5) 解析原始响应,覆盖 body 进入下一轮
var rawResp map[string]any var rawResp map[string]any
if err = json.Unmarshal(rawData, &rawResp); err != nil { if err := json.Unmarshal(rawData, &rawResp); err != nil {
g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err) g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err)
continue continue
} }
@@ -361,6 +378,44 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
return body, nil return body, nil
} }
// injectErrorMessage returns a request payload in which the last message with
// role "user" carries a correction hint built from err.
//
// The input payload is never modified. When a hint is added, the result is a
// copy of payload whose "messages" slice, target message map and (for list
// content) content slice are freshly allocated; everything else is shared with
// the input. This keeps hints from piling up on the caller's payload when the
// function is called once per retry with the same map.
//
// The input payload is returned unchanged when err is nil, when
// payload["messages"] is not a non-empty []any, when no user message exists, or
// when the last user message has content that is neither a string nor a []any.
func injectErrorMessage(payload map[string]any, err error) map[string]any {
	if err == nil {
		return payload
	}

	messages, _ := payload["messages"].([]any)
	if len(messages) == 0 {
		return payload
	}

	errMsg := fmt.Sprintf("\n\n【上一轮输出错误请修正】%s", err.Error())

	// Scan from the end: only the most recent user message receives the hint.
	for i := len(messages) - 1; i >= 0; i-- {
		msg, ok := messages[i].(map[string]any)
		if !ok || gconv.String(msg["role"]) != "user" {
			continue
		}

		var content any
		switch c := msg["content"].(type) {
		case string:
			content = c + errMsg
		case []any:
			// Build a new slice instead of appending to c, which could write
			// into spare capacity of the caller's backing array.
			parts := make([]any, 0, len(c)+1)
			parts = append(parts, c...)
			content = append(parts, map[string]any{
				"type": "text",
				"text": errMsg,
			})
		default:
			// Unsupported content shape: nothing to append; earlier user
			// messages are intentionally not considered.
			return payload
		}

		patchedMsg := make(map[string]any, len(msg))
		for k, v := range msg {
			patchedMsg[k] = v
		}
		patchedMsg["content"] = content

		patchedMessages := make([]any, len(messages))
		copy(patchedMessages, messages)
		patchedMessages[i] = patchedMsg

		patched := make(map[string]any, len(payload))
		for k, v := range payload {
			patched[k] = v
		}
		patched["messages"] = patchedMessages
		return patched
	}

	return payload
}
// InvokeModel 调用模型服务,返回二进制结果 // InvokeModel 调用模型服务,返回二进制结果
// modelKey 用于覆盖/补充模型配置 head_msg例如每次请求携带不同的 X-API-Key // modelKey 用于覆盖/补充模型配置 head_msg例如每次请求携带不同的 X-API-Key
func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) { func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) {