refactor(model): 优化模型网关的数据解析和任务处理逻辑

This commit is contained in:
2026-06-17 14:34:48 +08:00
parent b3b111995e
commit fddaf36f48
7 changed files with 231 additions and 166 deletions

View File

@@ -19,17 +19,20 @@ import (
tgjson "github.com/tidwall/gjson"
)
// ParseAndValidate 解析并校验结果
// ParseAndValidate 解析模型响应,并返回标准格式
func ParseAndValidate(raw map[string]any, model *entity.ModelGatewayModel) (map[string]any, error) {
// 1) 解析 content 字符串为 rounds 数组
contentVal, ok := raw[model.ResponseBody]
if !ok {
return raw, fmt.Errorf("字段 %s 不存在", model.ResponseBody)
}
contentStr, ok := contentVal.(string)
if !ok || strings.TrimSpace(contentStr) == "" {
return raw, fmt.Errorf("字段 %s 为空或不是字符串", model.ResponseBody)
contentStr := gconv.String(raw[entity.ResponseBody])
if strings.TrimSpace(contentStr) == "" {
return raw, fmt.Errorf("字段 %s 为空", entity.ResponseBody)
}
contentStr = strings.Map(func(r rune) rune {
if r < 32 && r != ' ' {
return -1
}
return r
}, contentStr)
var arr []any
if err := json.Unmarshal([]byte(contentStr), &arr); err != nil {
return raw, fmt.Errorf("JSON解析失败: %w", err)
@@ -38,17 +41,11 @@ func ParseAndValidate(raw map[string]any, model *entity.ModelGatewayModel) (map[
return raw, fmt.Errorf("解析后数组为空")
}
// 2) 校验必填字段
if len(model.RequiredFields) > 0 {
for _, field := range model.RequiredFields {
for i, r := range arr {
round, ok := r.(map[string]any)
if !ok {
continue
}
for _, field := range model.RequiredFields {
if gjson.New(round).Get(field).IsNil() {
return raw, fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, field)
}
round, _ := r.(map[string]any)
if round != nil && gjson.New(round).Get(field).IsNil() {
return raw, fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, field)
}
}
}

View File

@@ -56,6 +56,7 @@ func (d *modelGatewayModelsDao) Get(ctx context.Context, req *entity.ModelGatewa
Where(entity.ModelGatewayModelCol.Id, req.Id).
Where(entity.ModelGatewayModelCol.Creator, req.Creator).
Where(entity.ModelGatewayModelCol.ModelName, req.ModelName).
Where(entity.ModelGatewayModelCol.IsChatModel, req.IsChatModel).
Fields(fields).One()
if err != nil {
return nil, err
@@ -122,7 +123,7 @@ func (d *modelGatewayModelsDao) GetByAcrossTenant(ctx context.Context, req *enti
func (d *modelGatewayModelsDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.ModelGatewayModel, total int, err error) {
sql := `
SELECT DISTINCT ON (model_name) *
FROM asynch_models
FROM ` + public.TableNameModel + `
WHERE deleted_at IS NULL
AND (? = '' OR model_name LIKE ?)
`

View File

@@ -7,7 +7,6 @@ import (
"model-gateway/model/entity"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/util/gconv"
)
@@ -128,32 +127,32 @@ func (d *modelGatewayTaskDao) GetPendingAsyncTasks(ctx context.Context, limit in
// ClaimByID claims a task by primary key and returns the claimed task.
//
// The task row is only read while its state is public.TaskStatusPending; an
// empty result is reported as an error ("already claimed or does not exist").
// The state is then written as public.TaskStatusRunning.
//
// NOTE(review): this span appears to show two variants of the body back to
// back — a transactional one that uses LockUpdate, and a two-step one that
// repeats the pending-state filter on the update. Confirm which variant the
// real file contains before relying on either description.
func (d *modelGatewayTaskDao) ClaimByID(ctx context.Context, id int64) (*entity.ModelGatewayTask, error) {
// 1) Look up the task first.
var task entity.ModelGatewayTask
err := gfdb.DB(ctx, public.DbNameModelGateway).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
r, err := tx.Model(public.TableNameTask).
Where(entity.ModelGatewayTaskCol.Id, id).
Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending).
Limit(1).
LockUpdate().
One()
if err != nil {
return err
}
if r.IsEmpty() {
return fmt.Errorf("任务已被抢占或不存在: id=%d", id)
}
if err := r.Struct(&task); err != nil {
return err
}
_, err = tx.Model(public.TableNameTask).
Data(&entity.ModelGatewayTask{State: public.TaskStatusRunning}).
Where(entity.ModelGatewayTaskCol.Id, id).
OmitEmpty().
Update()
return err
})
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
Where(entity.ModelGatewayTaskCol.Id, id).
Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending).
One()
if err != nil {
return nil, err
}
if r.IsEmpty() {
return nil, fmt.Errorf("任务已被抢占或不存在: id=%d", id)
}
if err = r.Struct(&task); err != nil {
return nil, err
}
// 2) Switch the task to the running state.
// NOTE(review): the first return value of Update is discarded here, so if the
// pending-state filter below matches no row (for example because another
// caller changed the state after the read above) this function still returns
// the task without an error — confirm whether that is intended.
_, err = gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
Data(&entity.ModelGatewayTask{State: public.TaskStatusRunning}).
Where(entity.ModelGatewayTaskCol.Id, id).
Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending). // guard against concurrent claims
OmitEmpty().
Update()
if err != nil {
return nil, err
}
return &task, nil
}

View File

@@ -4,99 +4,98 @@ import "gitea.redpowerfuture.com/red-future/common/beans"
type modelGatewayModelCol struct {
beans.SQLBaseCol
ModelName string
ModelType string
BaseURL string
HttpMethod string
HeadMsg string
FormJSON string
RequestMapping string
ResponseMapping string
ResponseBody string
ResponseTokenField string
RequiredFields string
IsPrivate string
IsChatModel string
CallMode string
ApiKey string
Enabled string
MaxConcurrency string
TimeoutSeconds string
RetryTimes string
AutoCleanSeconds string
IsOwner string
OperatorName string
TokenConfig string
ExtendMapping string
QueryConfig string
StreamConfig string
FirstFrame string
LastFrame string
MaxTokens string
ModelName string
ModelType string
BaseURL string
HttpMethod string
HeadMsg string
FormJSON string
RequestMapping string
ResponseMapping string
RequiredFields string
IsPrivate string
IsChatModel string
CallMode string
ApiKey string
Enabled string
MaxConcurrency string
TimeoutSeconds string
RetryTimes string
AutoCleanSeconds string
IsOwner string
OperatorName string
TokenConfig string
ExtendMapping string
QueryConfig string
StreamConfig string
FirstFrame string
LastFrame string
MaxTokens string
}
var ModelGatewayModelCol = modelGatewayModelCol{
SQLBaseCol: beans.DefSQLBaseCol,
ModelName: "model_name",
ModelType: "model_type",
BaseURL: "base_url",
HttpMethod: "http_method",
HeadMsg: "head_msg",
FormJSON: "form_json",
RequestMapping: "request_mapping",
ResponseMapping: "response_mapping",
ResponseBody: "response_body",
ResponseTokenField: "response_token_field",
RequiredFields: "required_fields",
IsPrivate: "is_private",
IsChatModel: "is_chat_model",
CallMode: "call_mode",
ApiKey: "api_key",
Enabled: "enabled",
MaxConcurrency: "max_concurrency",
TimeoutSeconds: "timeout_seconds",
RetryTimes: "retry_times",
AutoCleanSeconds: "auto_clean_seconds",
IsOwner: "is_owner",
OperatorName: "operator_name",
TokenConfig: "token_config",
ExtendMapping: "extend_mapping",
QueryConfig: "query_config",
StreamConfig: "stream_config",
FirstFrame: "first_frame",
LastFrame: "last_frame",
MaxTokens: "max_tokens",
SQLBaseCol: beans.DefSQLBaseCol,
ModelName: "model_name",
ModelType: "model_type",
BaseURL: "base_url",
HttpMethod: "http_method",
HeadMsg: "head_msg",
FormJSON: "form_json",
RequestMapping: "request_mapping",
ResponseMapping: "response_mapping",
RequiredFields: "required_fields",
IsPrivate: "is_private",
IsChatModel: "is_chat_model",
CallMode: "call_mode",
ApiKey: "api_key",
Enabled: "enabled",
MaxConcurrency: "max_concurrency",
TimeoutSeconds: "timeout_seconds",
RetryTimes: "retry_times",
AutoCleanSeconds: "auto_clean_seconds",
IsOwner: "is_owner",
OperatorName: "operator_name",
TokenConfig: "token_config",
ExtendMapping: "extend_mapping",
QueryConfig: "query_config",
StreamConfig: "stream_config",
FirstFrame: "first_frame",
LastFrame: "last_frame",
MaxTokens: "max_tokens",
}
type ModelGatewayModel struct {
beans.SQLBaseDO `orm:",inline"`
ModelName string `orm:"model_name" json:"modelName"`
ModelType int `orm:"model_type" json:"modelType"`
BaseURL string `orm:"base_url" json:"baseUrl"`
HttpMethod string `orm:"http_method" json:"httpMethod"`
HeadMsg map[string]any `orm:"head_msg" json:"headMsg"`
Form []map[string]any `orm:"form_json" json:"form"`
RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"`
ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"`
ResponseBody string `orm:"response_body" json:"responseBody"`
ResponseTokenField string `orm:"response_token_field" json:"tokenField"`
RequiredFields []string `orm:"required_fields" json:"requiredFields"`
IsPrivate *int `orm:"is_private" json:"isPrivate"`
IsChatModel *int `orm:"is_chat_model" json:"isChatModel"`
CallMode *int `orm:"call_mode" json:"callMode"`
ApiKey string `orm:"api_key" json:"apiKey"`
Enabled *int `orm:"enabled" json:"enabled"`
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
RetryTimes int `orm:"retry_times" json:"retryTimes"`
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
IsOwner *int `orm:"is_owner" json:"isOwner"`
OperatorName string `orm:"operator_name" json:"operatorName"`
TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"`
ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"`
QueryConfig map[string]any `orm:"query_config" json:"queryConfig"`
StreamConfig map[string]any `orm:"stream_config" json:"streamConfig"`
FirstFrame string `orm:"first_frame" json:"firstFrame"`
LastFrame string `orm:"last_frame" json:"lastFrame"`
MaxTokens int `orm:"max_tokens" json:"maxTokens"`
beans.SQLBaseDO `orm:",inline"`
ModelName string `orm:"model_name" json:"modelName"`
ModelType int `orm:"model_type" json:"modelType"`
BaseURL string `orm:"base_url" json:"baseUrl"`
HttpMethod string `orm:"http_method" json:"httpMethod"`
HeadMsg map[string]any `orm:"head_msg" json:"headMsg"`
Form []map[string]any `orm:"form_json" json:"form"`
RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"`
ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"`
RequiredFields []string `orm:"required_fields" json:"requiredFields"`
IsPrivate *int `orm:"is_private" json:"isPrivate"`
IsChatModel *int `orm:"is_chat_model" json:"isChatModel"`
CallMode *int `orm:"call_mode" json:"callMode"`
ApiKey string `orm:"api_key" json:"apiKey"`
Enabled *int `orm:"enabled" json:"enabled"`
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
RetryTimes int `orm:"retry_times" json:"retryTimes"`
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
IsOwner *int `orm:"is_owner" json:"isOwner"`
OperatorName string `orm:"operator_name" json:"operatorName"`
TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"`
ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"`
QueryConfig map[string]any `orm:"query_config" json:"queryConfig"`
StreamConfig map[string]any `orm:"stream_config" json:"streamConfig"`
FirstFrame string `orm:"first_frame" json:"firstFrame"`
LastFrame string `orm:"last_frame" json:"lastFrame"`
MaxTokens int `orm:"max_tokens" json:"maxTokens"`
}
// Fixed key names used when reading a model response.
const (
ResponseBody = "content" // main response body (required)
TotalTokens = "total_tokens" // total token count
)

View File

@@ -99,7 +99,7 @@ func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetM
ModelName: req.ModelName,
IsChatModel: req.IsChatModel,
})
if err != nil {
if err != nil || model == nil {
return nil, err
}
return &dto.GetModelRes{

View File

@@ -107,13 +107,27 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
},
})
// 5) 获取任务信息
task, err := dao.ModelGatewayTask.ClaimByID(ctx, id)
// 5) 抢占任务:改为执行中
rows, err := dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
SQLBaseDO: beans.SQLBaseDO{Id: id},
State: public.TaskStatusRunning,
})
if err != nil {
return nil, err
}
if rows == 0 {
return nil, fmt.Errorf("任务不存在: id=%d", id)
}
// 6) 查询任务信息
task, err := dao.ModelGatewayTask.Get(ctx, &entity.ModelGatewayTask{
SQLBaseDO: beans.SQLBaseDO{Id: id},
})
if err != nil {
return nil, err
}
// 5) 创建成功后立即异步尝试执行当前任务
// 7) 创建成功后立即异步尝试执行当前任务
go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req)
return &dto.CreateTaskRes{TaskID: taskID}, nil

View File

@@ -67,25 +67,39 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa
// ============================================
// 2) 调用模型
// ============================================
switch {
case model.CallMode != nil && *model.CallMode == public.CallModeStream:
rawBytes, streamErr := w.callModelStream(ctx, task, model, body)
if streamErr != nil {
w.failTask(ctx, task, startTime, streamErr.Error())
for attempt := 0; ; attempt++ {
if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] 调用模型 第%d次 taskId=%s", attempt, task.TaskID)
time.Sleep(time.Duration(attempt) * time.Second)
}
switch {
case model.CallMode != nil && *model.CallMode == public.CallModeStream:
rawBytes, streamErr := w.callModelStream(ctx, task, model, body)
if streamErr != nil {
err = streamErr
g.Log().Warningf(ctx, "[执行任务][调用失败] taskId=%s attempt=%d err=%v", task.TaskID, attempt, err)
continue
}
result, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
case model.CallMode != nil && *model.CallMode == public.CallModeAsync:
result, err = w.callModel(ctx, task, model, body)
if err == nil {
result, err = util.PullTaskResult(ctx, result, model.QueryConfig, model.HeadMsg)
}
default:
result, err = w.callModel(ctx, task, model, body)
}
if err == nil {
break
}
if !strings.Contains(err.Error(), "Timeout") {
w.failTask(ctx, task, startTime, err.Error())
return
}
result, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
case model.CallMode != nil && *model.CallMode == public.CallModeAsync:
result, err = w.callModel(ctx, task, model, body)
if err == nil {
result, err = util.PullTaskResult(ctx, result, model.QueryConfig, model.HeadMsg)
}
default:
result, err = w.callModel(ctx, task, model, body)
}
if err != nil {
w.failTask(ctx, task, startTime, err.Error())
return
g.Log().Warningf(ctx, "[执行任务][调用失败] taskId=%s attempt=%d err=%v", task.TaskID, attempt, err)
}
// ============================================
@@ -205,7 +219,7 @@ func (w *asyncWorker) callModelAsync(ctx context.Context, task *entity.ModelGate
return nil, err
}
// 2. 拿到 task_id
taskID := gjson.New(body).Get(model.ResponseBody).String()
taskID := gjson.New(body).Get(entity.ResponseBody).String()
// 3. 创建等待通道
ch := make(chan asyncResult, 1)
@@ -294,6 +308,8 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.ModelGatewayTa
// parseAndRetry 解析模型返回结果,并重试
func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) {
var lastErr error
for attempt := 0; attempt <= maxRetry; attempt++ {
if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
@@ -302,6 +318,7 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
// 1) 响应映射
mapped, err := util.MapResponsePayload(model.ResponseMapping, body)
if err != nil {
lastErr = err
g.Log().Warningf(ctx, "[执行任务][映射失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
if attempt == maxRetry {
return nil, fmt.Errorf("响应映射重试耗尽: %w", err)
@@ -309,10 +326,10 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
continue
}
// 2) 存 token 到数据库,防止后续失败丢失
if _, ok := mapped[model.ResponseTokenField]; ok {
task.ExpendTokens = gconv.Int64(mapped[model.ResponseTokenField])
_, err = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
// 2) 存 token
if _, ok := mapped[entity.TotalTokens]; ok {
task.ExpendTokens = gconv.Int64(mapped[entity.TotalTokens])
_, _ = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
SQLBaseDO: beans.SQLBaseDO{Id: task.Id},
ExpendTokens: task.ExpendTokens,
})
@@ -326,9 +343,9 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
if err == nil {
return parsed, nil
}
lastErr = err
case public.BuildTypeStruct:
parsed = util.ParseStructResult(mapped, model.ResponseBody)
return parsed, nil
return util.ParseStructResult(mapped, entity.ResponseBody), nil
default:
return mapped, nil
}
@@ -336,22 +353,22 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
if attempt == maxRetry {
return nil, fmt.Errorf("JSON解析重试耗尽: %w", err)
return nil, fmt.Errorf("JSON解析重试耗尽: %w", lastErr)
}
// 4) 重新调模型(直接调,不走缓存)
// 4) 拼接错误信息到请求体,重调模型
task.RetryCount++
_, _ = dao.ModelGatewayTask.Update(ctx, task)
rawData, callErr := InvokeModel(ctx, model, task.RequestPayload.Body)
body = injectErrorMessage(task.RequestPayload.Body, lastErr)
rawData, callErr := InvokeModel(ctx, model, body)
if callErr != nil {
g.Log().Warningf(ctx, "[执行任务][重调模型失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, callErr)
continue
}
// 5) 解析原始响应,覆盖 body 进入下一轮
var rawResp map[string]any
if err = json.Unmarshal(rawData, &rawResp); err != nil {
if err := json.Unmarshal(rawData, &rawResp); err != nil {
g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err)
continue
}
@@ -361,6 +378,44 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
return body, nil
}
// injectErrorMessage returns a request payload in which the last message whose
// role is "user" carries a correction hint built from err.
//
// The hint is appended to the message content: concatenated when the content
// is a string, or added as an extra {"type": "text", "text": ...} part when
// the content is a []any. Content of any other type is left alone.
//
// The input payload is never modified. When a hint is added, the payload map,
// the messages slice, the patched message and (for []any content) the content
// slice are copied first. Callers pass the task's stored request body on every
// retry; patching it in place would stack one hint per attempt and leak the
// hints into whatever later persists that body.
//
// payload is returned unchanged when err is nil, when payload["messages"] is
// not a non-empty []any, when no user message exists, or when the last user
// message has content that is neither a string nor a []any.
func injectErrorMessage(payload map[string]any, err error) map[string]any {
	if err == nil {
		return payload
	}
	messages, _ := payload["messages"].([]any)
	if len(messages) == 0 {
		return payload
	}

	errMsg := fmt.Sprintf("\n\n【上一轮输出错误请修正】%s", err.Error())

	// Walk backwards so that only the most recent user message is considered.
	for i := len(messages) - 1; i >= 0; i-- {
		msg, ok := messages[i].(map[string]any)
		if !ok || gconv.String(msg["role"]) != "user" {
			continue
		}

		var content any
		switch c := msg["content"].(type) {
		case string:
			content = c + errMsg
		case []any:
			// Copy before appending: append on the original slice could write
			// into a backing array that the caller still shares.
			parts := make([]any, len(c), len(c)+1)
			copy(parts, c)
			content = append(parts, map[string]any{
				"type": "text",
				"text": errMsg,
			})
		default:
			// Unsupported content shape: nothing to patch.
			return payload
		}

		patchedMsg := make(map[string]any, len(msg))
		for k, v := range msg {
			patchedMsg[k] = v
		}
		patchedMsg["content"] = content

		patchedMessages := make([]any, len(messages))
		copy(patchedMessages, messages)
		patchedMessages[i] = patchedMsg

		patched := make(map[string]any, len(payload))
		for k, v := range payload {
			patched[k] = v
		}
		patched["messages"] = patchedMessages
		return patched
	}
	return payload
}
// InvokeModel 调用模型服务,返回二进制结果
// modelKey 用于覆盖/补充模型配置 head_msg例如每次请求携带不同的 X-API-Key
func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) {