refactor(model): 优化模型网关的数据解析和任务处理逻辑
This commit is contained in:
@@ -19,17 +19,20 @@ import (
|
|||||||
tgjson "github.com/tidwall/gjson"
|
tgjson "github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ParseAndValidate 解析并校验结果
|
// ParseAndValidate 解析模型响应,并返回标准格式
|
||||||
func ParseAndValidate(raw map[string]any, model *entity.ModelGatewayModel) (map[string]any, error) {
|
func ParseAndValidate(raw map[string]any, model *entity.ModelGatewayModel) (map[string]any, error) {
|
||||||
// 1) 解析 content 字符串为 rounds 数组
|
contentStr := gconv.String(raw[entity.ResponseBody])
|
||||||
contentVal, ok := raw[model.ResponseBody]
|
if strings.TrimSpace(contentStr) == "" {
|
||||||
if !ok {
|
return raw, fmt.Errorf("字段 %s 为空", entity.ResponseBody)
|
||||||
return raw, fmt.Errorf("字段 %s 不存在", model.ResponseBody)
|
|
||||||
}
|
|
||||||
contentStr, ok := contentVal.(string)
|
|
||||||
if !ok || strings.TrimSpace(contentStr) == "" {
|
|
||||||
return raw, fmt.Errorf("字段 %s 为空或不是字符串", model.ResponseBody)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
contentStr = strings.Map(func(r rune) rune {
|
||||||
|
if r < 32 && r != ' ' {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}, contentStr)
|
||||||
|
|
||||||
var arr []any
|
var arr []any
|
||||||
if err := json.Unmarshal([]byte(contentStr), &arr); err != nil {
|
if err := json.Unmarshal([]byte(contentStr), &arr); err != nil {
|
||||||
return raw, fmt.Errorf("JSON解析失败: %w", err)
|
return raw, fmt.Errorf("JSON解析失败: %w", err)
|
||||||
@@ -38,17 +41,11 @@ func ParseAndValidate(raw map[string]any, model *entity.ModelGatewayModel) (map[
|
|||||||
return raw, fmt.Errorf("解析后数组为空")
|
return raw, fmt.Errorf("解析后数组为空")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2) 校验必填字段
|
for _, field := range model.RequiredFields {
|
||||||
if len(model.RequiredFields) > 0 {
|
|
||||||
for i, r := range arr {
|
for i, r := range arr {
|
||||||
round, ok := r.(map[string]any)
|
round, _ := r.(map[string]any)
|
||||||
if !ok {
|
if round != nil && gjson.New(round).Get(field).IsNil() {
|
||||||
continue
|
return raw, fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, field)
|
||||||
}
|
|
||||||
for _, field := range model.RequiredFields {
|
|
||||||
if gjson.New(round).Get(field).IsNil() {
|
|
||||||
return raw, fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, field)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ func (d *modelGatewayModelsDao) Get(ctx context.Context, req *entity.ModelGatewa
|
|||||||
Where(entity.ModelGatewayModelCol.Id, req.Id).
|
Where(entity.ModelGatewayModelCol.Id, req.Id).
|
||||||
Where(entity.ModelGatewayModelCol.Creator, req.Creator).
|
Where(entity.ModelGatewayModelCol.Creator, req.Creator).
|
||||||
Where(entity.ModelGatewayModelCol.ModelName, req.ModelName).
|
Where(entity.ModelGatewayModelCol.ModelName, req.ModelName).
|
||||||
|
Where(entity.ModelGatewayModelCol.IsChatModel, req.IsChatModel).
|
||||||
Fields(fields).One()
|
Fields(fields).One()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -122,7 +123,7 @@ func (d *modelGatewayModelsDao) GetByAcrossTenant(ctx context.Context, req *enti
|
|||||||
func (d *modelGatewayModelsDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.ModelGatewayModel, total int, err error) {
|
func (d *modelGatewayModelsDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.ModelGatewayModel, total int, err error) {
|
||||||
sql := `
|
sql := `
|
||||||
SELECT DISTINCT ON (model_name) *
|
SELECT DISTINCT ON (model_name) *
|
||||||
FROM asynch_models
|
FROM ` + public.TableNameModel + `
|
||||||
WHERE deleted_at IS NULL
|
WHERE deleted_at IS NULL
|
||||||
AND (? = '' OR model_name LIKE ?)
|
AND (? = '' OR model_name LIKE ?)
|
||||||
`
|
`
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"model-gateway/model/entity"
|
"model-gateway/model/entity"
|
||||||
|
|
||||||
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
|
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
|
||||||
"github.com/gogf/gf/v2/database/gdb"
|
|
||||||
"github.com/gogf/gf/v2/util/gconv"
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -128,32 +127,32 @@ func (d *modelGatewayTaskDao) GetPendingAsyncTasks(ctx context.Context, limit in
|
|||||||
|
|
||||||
// ClaimByID 按主键抢占,返回抢占后的任务
|
// ClaimByID 按主键抢占,返回抢占后的任务
|
||||||
func (d *modelGatewayTaskDao) ClaimByID(ctx context.Context, id int64) (*entity.ModelGatewayTask, error) {
|
func (d *modelGatewayTaskDao) ClaimByID(ctx context.Context, id int64) (*entity.ModelGatewayTask, error) {
|
||||||
|
// 1) 先查任务
|
||||||
var task entity.ModelGatewayTask
|
var task entity.ModelGatewayTask
|
||||||
err := gfdb.DB(ctx, public.DbNameModelGateway).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
|
||||||
r, err := tx.Model(public.TableNameTask).
|
Where(entity.ModelGatewayTaskCol.Id, id).
|
||||||
Where(entity.ModelGatewayTaskCol.Id, id).
|
Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending).
|
||||||
Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending).
|
One()
|
||||||
Limit(1).
|
|
||||||
LockUpdate().
|
|
||||||
One()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if r.IsEmpty() {
|
|
||||||
return fmt.Errorf("任务已被抢占或不存在: id=%d", id)
|
|
||||||
}
|
|
||||||
if err := r.Struct(&task); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = tx.Model(public.TableNameTask).
|
|
||||||
Data(&entity.ModelGatewayTask{State: public.TaskStatusRunning}).
|
|
||||||
Where(entity.ModelGatewayTaskCol.Id, id).
|
|
||||||
OmitEmpty().
|
|
||||||
Update()
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if r.IsEmpty() {
|
||||||
|
return nil, fmt.Errorf("任务已被抢占或不存在: id=%d", id)
|
||||||
|
}
|
||||||
|
if err = r.Struct(&task); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2) 改为执行中
|
||||||
|
_, err = gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
|
||||||
|
Data(&entity.ModelGatewayTask{State: public.TaskStatusRunning}).
|
||||||
|
Where(entity.ModelGatewayTaskCol.Id, id).
|
||||||
|
Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending). // 防并发
|
||||||
|
OmitEmpty().
|
||||||
|
Update()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return &task, nil
|
return &task, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,99 +4,98 @@ import "gitea.redpowerfuture.com/red-future/common/beans"
|
|||||||
|
|
||||||
type modelGatewayModelCol struct {
|
type modelGatewayModelCol struct {
|
||||||
beans.SQLBaseCol
|
beans.SQLBaseCol
|
||||||
ModelName string
|
ModelName string
|
||||||
ModelType string
|
ModelType string
|
||||||
BaseURL string
|
BaseURL string
|
||||||
HttpMethod string
|
HttpMethod string
|
||||||
HeadMsg string
|
HeadMsg string
|
||||||
FormJSON string
|
FormJSON string
|
||||||
RequestMapping string
|
RequestMapping string
|
||||||
ResponseMapping string
|
ResponseMapping string
|
||||||
ResponseBody string
|
RequiredFields string
|
||||||
ResponseTokenField string
|
IsPrivate string
|
||||||
RequiredFields string
|
IsChatModel string
|
||||||
IsPrivate string
|
CallMode string
|
||||||
IsChatModel string
|
ApiKey string
|
||||||
CallMode string
|
Enabled string
|
||||||
ApiKey string
|
MaxConcurrency string
|
||||||
Enabled string
|
TimeoutSeconds string
|
||||||
MaxConcurrency string
|
RetryTimes string
|
||||||
TimeoutSeconds string
|
AutoCleanSeconds string
|
||||||
RetryTimes string
|
IsOwner string
|
||||||
AutoCleanSeconds string
|
OperatorName string
|
||||||
IsOwner string
|
TokenConfig string
|
||||||
OperatorName string
|
ExtendMapping string
|
||||||
TokenConfig string
|
QueryConfig string
|
||||||
ExtendMapping string
|
StreamConfig string
|
||||||
QueryConfig string
|
FirstFrame string
|
||||||
StreamConfig string
|
LastFrame string
|
||||||
FirstFrame string
|
MaxTokens string
|
||||||
LastFrame string
|
|
||||||
MaxTokens string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var ModelGatewayModelCol = modelGatewayModelCol{
|
var ModelGatewayModelCol = modelGatewayModelCol{
|
||||||
SQLBaseCol: beans.DefSQLBaseCol,
|
SQLBaseCol: beans.DefSQLBaseCol,
|
||||||
ModelName: "model_name",
|
ModelName: "model_name",
|
||||||
ModelType: "model_type",
|
ModelType: "model_type",
|
||||||
BaseURL: "base_url",
|
BaseURL: "base_url",
|
||||||
HttpMethod: "http_method",
|
HttpMethod: "http_method",
|
||||||
HeadMsg: "head_msg",
|
HeadMsg: "head_msg",
|
||||||
FormJSON: "form_json",
|
FormJSON: "form_json",
|
||||||
RequestMapping: "request_mapping",
|
RequestMapping: "request_mapping",
|
||||||
ResponseMapping: "response_mapping",
|
ResponseMapping: "response_mapping",
|
||||||
ResponseBody: "response_body",
|
RequiredFields: "required_fields",
|
||||||
ResponseTokenField: "response_token_field",
|
IsPrivate: "is_private",
|
||||||
RequiredFields: "required_fields",
|
IsChatModel: "is_chat_model",
|
||||||
IsPrivate: "is_private",
|
CallMode: "call_mode",
|
||||||
IsChatModel: "is_chat_model",
|
ApiKey: "api_key",
|
||||||
CallMode: "call_mode",
|
Enabled: "enabled",
|
||||||
ApiKey: "api_key",
|
MaxConcurrency: "max_concurrency",
|
||||||
Enabled: "enabled",
|
TimeoutSeconds: "timeout_seconds",
|
||||||
MaxConcurrency: "max_concurrency",
|
RetryTimes: "retry_times",
|
||||||
TimeoutSeconds: "timeout_seconds",
|
AutoCleanSeconds: "auto_clean_seconds",
|
||||||
RetryTimes: "retry_times",
|
IsOwner: "is_owner",
|
||||||
AutoCleanSeconds: "auto_clean_seconds",
|
OperatorName: "operator_name",
|
||||||
IsOwner: "is_owner",
|
TokenConfig: "token_config",
|
||||||
OperatorName: "operator_name",
|
ExtendMapping: "extend_mapping",
|
||||||
TokenConfig: "token_config",
|
QueryConfig: "query_config",
|
||||||
ExtendMapping: "extend_mapping",
|
StreamConfig: "stream_config",
|
||||||
QueryConfig: "query_config",
|
FirstFrame: "first_frame",
|
||||||
StreamConfig: "stream_config",
|
LastFrame: "last_frame",
|
||||||
FirstFrame: "first_frame",
|
MaxTokens: "max_tokens",
|
||||||
LastFrame: "last_frame",
|
|
||||||
MaxTokens: "max_tokens",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type ModelGatewayModel struct {
|
type ModelGatewayModel struct {
|
||||||
beans.SQLBaseDO `orm:",inline"`
|
beans.SQLBaseDO `orm:",inline"`
|
||||||
ModelName string `orm:"model_name" json:"modelName"`
|
ModelName string `orm:"model_name" json:"modelName"`
|
||||||
ModelType int `orm:"model_type" json:"modelType"`
|
ModelType int `orm:"model_type" json:"modelType"`
|
||||||
BaseURL string `orm:"base_url" json:"baseUrl"`
|
BaseURL string `orm:"base_url" json:"baseUrl"`
|
||||||
HttpMethod string `orm:"http_method" json:"httpMethod"`
|
HttpMethod string `orm:"http_method" json:"httpMethod"`
|
||||||
HeadMsg map[string]any `orm:"head_msg" json:"headMsg"`
|
HeadMsg map[string]any `orm:"head_msg" json:"headMsg"`
|
||||||
Form []map[string]any `orm:"form_json" json:"form"`
|
Form []map[string]any `orm:"form_json" json:"form"`
|
||||||
RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"`
|
RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"`
|
||||||
ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"`
|
ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"`
|
||||||
ResponseBody string `orm:"response_body" json:"responseBody"`
|
RequiredFields []string `orm:"required_fields" json:"requiredFields"`
|
||||||
ResponseTokenField string `orm:"response_token_field" json:"tokenField"`
|
IsPrivate *int `orm:"is_private" json:"isPrivate"`
|
||||||
RequiredFields []string `orm:"required_fields" json:"requiredFields"`
|
IsChatModel *int `orm:"is_chat_model" json:"isChatModel"`
|
||||||
IsPrivate *int `orm:"is_private" json:"isPrivate"`
|
CallMode *int `orm:"call_mode" json:"callMode"`
|
||||||
IsChatModel *int `orm:"is_chat_model" json:"isChatModel"`
|
ApiKey string `orm:"api_key" json:"apiKey"`
|
||||||
CallMode *int `orm:"call_mode" json:"callMode"`
|
Enabled *int `orm:"enabled" json:"enabled"`
|
||||||
ApiKey string `orm:"api_key" json:"apiKey"`
|
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
|
||||||
Enabled *int `orm:"enabled" json:"enabled"`
|
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
|
||||||
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
|
RetryTimes int `orm:"retry_times" json:"retryTimes"`
|
||||||
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
|
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
|
||||||
RetryTimes int `orm:"retry_times" json:"retryTimes"`
|
IsOwner *int `orm:"is_owner" json:"isOwner"`
|
||||||
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
|
OperatorName string `orm:"operator_name" json:"operatorName"`
|
||||||
IsOwner *int `orm:"is_owner" json:"isOwner"`
|
TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"`
|
||||||
OperatorName string `orm:"operator_name" json:"operatorName"`
|
ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"`
|
||||||
TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"`
|
QueryConfig map[string]any `orm:"query_config" json:"queryConfig"`
|
||||||
ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"`
|
StreamConfig map[string]any `orm:"stream_config" json:"streamConfig"`
|
||||||
QueryConfig map[string]any `orm:"query_config" json:"queryConfig"`
|
FirstFrame string `orm:"first_frame" json:"firstFrame"`
|
||||||
StreamConfig map[string]any `orm:"stream_config" json:"streamConfig"`
|
LastFrame string `orm:"last_frame" json:"lastFrame"`
|
||||||
FirstFrame string `orm:"first_frame" json:"firstFrame"`
|
MaxTokens int `orm:"max_tokens" json:"maxTokens"`
|
||||||
LastFrame string `orm:"last_frame" json:"lastFrame"`
|
|
||||||
MaxTokens int `orm:"max_tokens" json:"maxTokens"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
ResponseBody = "content" //返回主体(必填)
|
||||||
|
TotalTokens = "total_tokens" //总token数
|
||||||
|
)
|
||||||
|
|||||||
@@ -99,7 +99,7 @@ func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetM
|
|||||||
ModelName: req.ModelName,
|
ModelName: req.ModelName,
|
||||||
IsChatModel: req.IsChatModel,
|
IsChatModel: req.IsChatModel,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil || model == nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &dto.GetModelRes{
|
return &dto.GetModelRes{
|
||||||
|
|||||||
@@ -107,13 +107,27 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
// 5) 获取任务信息
|
// 5) 抢占任务:改为执行中
|
||||||
task, err := dao.ModelGatewayTask.ClaimByID(ctx, id)
|
rows, err := dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
|
||||||
|
SQLBaseDO: beans.SQLBaseDO{Id: id},
|
||||||
|
State: public.TaskStatusRunning,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if rows == 0 {
|
||||||
|
return nil, fmt.Errorf("任务不存在: id=%d", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6) 查询任务信息
|
||||||
|
task, err := dao.ModelGatewayTask.Get(ctx, &entity.ModelGatewayTask{
|
||||||
|
SQLBaseDO: beans.SQLBaseDO{Id: id},
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 5) 创建成功后立即异步尝试执行当前任务
|
// 7) 创建成功后立即异步尝试执行当前任务
|
||||||
go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req)
|
go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req)
|
||||||
|
|
||||||
return &dto.CreateTaskRes{TaskID: taskID}, nil
|
return &dto.CreateTaskRes{TaskID: taskID}, nil
|
||||||
|
|||||||
@@ -67,25 +67,39 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa
|
|||||||
// ============================================
|
// ============================================
|
||||||
// 2) 调用模型
|
// 2) 调用模型
|
||||||
// ============================================
|
// ============================================
|
||||||
switch {
|
for attempt := 0; ; attempt++ {
|
||||||
case model.CallMode != nil && *model.CallMode == public.CallModeStream:
|
if attempt > 0 {
|
||||||
rawBytes, streamErr := w.callModelStream(ctx, task, model, body)
|
g.Log().Infof(ctx, "[执行任务][重试] 调用模型 第%d次 taskId=%s", attempt, task.TaskID)
|
||||||
if streamErr != nil {
|
time.Sleep(time.Duration(attempt) * time.Second)
|
||||||
w.failTask(ctx, task, startTime, streamErr.Error())
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case model.CallMode != nil && *model.CallMode == public.CallModeStream:
|
||||||
|
rawBytes, streamErr := w.callModelStream(ctx, task, model, body)
|
||||||
|
if streamErr != nil {
|
||||||
|
err = streamErr
|
||||||
|
g.Log().Warningf(ctx, "[执行任务][调用失败] taskId=%s attempt=%d err=%v", task.TaskID, attempt, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
|
||||||
|
case model.CallMode != nil && *model.CallMode == public.CallModeAsync:
|
||||||
|
result, err = w.callModel(ctx, task, model, body)
|
||||||
|
if err == nil {
|
||||||
|
result, err = util.PullTaskResult(ctx, result, model.QueryConfig, model.HeadMsg)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
result, err = w.callModel(ctx, task, model, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(err.Error(), "Timeout") {
|
||||||
|
w.failTask(ctx, task, startTime, err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
result, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
|
g.Log().Warningf(ctx, "[执行任务][调用失败] taskId=%s attempt=%d err=%v", task.TaskID, attempt, err)
|
||||||
case model.CallMode != nil && *model.CallMode == public.CallModeAsync:
|
|
||||||
result, err = w.callModel(ctx, task, model, body)
|
|
||||||
if err == nil {
|
|
||||||
result, err = util.PullTaskResult(ctx, result, model.QueryConfig, model.HeadMsg)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
result, err = w.callModel(ctx, task, model, body)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
w.failTask(ctx, task, startTime, err.Error())
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============================================
|
// ============================================
|
||||||
@@ -205,7 +219,7 @@ func (w *asyncWorker) callModelAsync(ctx context.Context, task *entity.ModelGate
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// 2. 拿到 task_id
|
// 2. 拿到 task_id
|
||||||
taskID := gjson.New(body).Get(model.ResponseBody).String()
|
taskID := gjson.New(body).Get(entity.ResponseBody).String()
|
||||||
|
|
||||||
// 3. 创建等待通道
|
// 3. 创建等待通道
|
||||||
ch := make(chan asyncResult, 1)
|
ch := make(chan asyncResult, 1)
|
||||||
@@ -294,6 +308,8 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.ModelGatewayTa
|
|||||||
|
|
||||||
// parseAndRetry 解析模型返回结果,并重试
|
// parseAndRetry 解析模型返回结果,并重试
|
||||||
func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) {
|
func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) {
|
||||||
|
var lastErr error
|
||||||
|
|
||||||
for attempt := 0; attempt <= maxRetry; attempt++ {
|
for attempt := 0; attempt <= maxRetry; attempt++ {
|
||||||
if attempt > 0 {
|
if attempt > 0 {
|
||||||
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
|
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
|
||||||
@@ -302,6 +318,7 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
|
|||||||
// 1) 响应映射
|
// 1) 响应映射
|
||||||
mapped, err := util.MapResponsePayload(model.ResponseMapping, body)
|
mapped, err := util.MapResponsePayload(model.ResponseMapping, body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
lastErr = err
|
||||||
g.Log().Warningf(ctx, "[执行任务][映射失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
|
g.Log().Warningf(ctx, "[执行任务][映射失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
|
||||||
if attempt == maxRetry {
|
if attempt == maxRetry {
|
||||||
return nil, fmt.Errorf("响应映射重试耗尽: %w", err)
|
return nil, fmt.Errorf("响应映射重试耗尽: %w", err)
|
||||||
@@ -309,10 +326,10 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2) 先存 token 到数据库,防止后续失败丢失
|
// 2) 存 token
|
||||||
if _, ok := mapped[model.ResponseTokenField]; ok {
|
if _, ok := mapped[entity.TotalTokens]; ok {
|
||||||
task.ExpendTokens = gconv.Int64(mapped[model.ResponseTokenField])
|
task.ExpendTokens = gconv.Int64(mapped[entity.TotalTokens])
|
||||||
_, err = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
|
_, _ = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
|
||||||
SQLBaseDO: beans.SQLBaseDO{Id: task.Id},
|
SQLBaseDO: beans.SQLBaseDO{Id: task.Id},
|
||||||
ExpendTokens: task.ExpendTokens,
|
ExpendTokens: task.ExpendTokens,
|
||||||
})
|
})
|
||||||
@@ -326,9 +343,9 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
return parsed, nil
|
return parsed, nil
|
||||||
}
|
}
|
||||||
|
lastErr = err
|
||||||
case public.BuildTypeStruct:
|
case public.BuildTypeStruct:
|
||||||
parsed = util.ParseStructResult(mapped, model.ResponseBody)
|
return util.ParseStructResult(mapped, entity.ResponseBody), nil
|
||||||
return parsed, nil
|
|
||||||
default:
|
default:
|
||||||
return mapped, nil
|
return mapped, nil
|
||||||
}
|
}
|
||||||
@@ -336,22 +353,22 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
|
|||||||
g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
|
g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
|
||||||
|
|
||||||
if attempt == maxRetry {
|
if attempt == maxRetry {
|
||||||
return nil, fmt.Errorf("JSON解析重试耗尽: %w", err)
|
return nil, fmt.Errorf("JSON解析重试耗尽: %w", lastErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4) 重新调模型(直接调,不走缓存)
|
// 4) 拼接错误信息到请求体,重调模型
|
||||||
task.RetryCount++
|
task.RetryCount++
|
||||||
_, _ = dao.ModelGatewayTask.Update(ctx, task)
|
_, _ = dao.ModelGatewayTask.Update(ctx, task)
|
||||||
rawData, callErr := InvokeModel(ctx, model, task.RequestPayload.Body)
|
|
||||||
|
|
||||||
|
body = injectErrorMessage(task.RequestPayload.Body, lastErr)
|
||||||
|
rawData, callErr := InvokeModel(ctx, model, body)
|
||||||
if callErr != nil {
|
if callErr != nil {
|
||||||
g.Log().Warningf(ctx, "[执行任务][重调模型失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, callErr)
|
g.Log().Warningf(ctx, "[执行任务][重调模型失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, callErr)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// 5) 解析原始响应,覆盖 body 进入下一轮
|
|
||||||
var rawResp map[string]any
|
var rawResp map[string]any
|
||||||
if err = json.Unmarshal(rawData, &rawResp); err != nil {
|
if err := json.Unmarshal(rawData, &rawResp); err != nil {
|
||||||
g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err)
|
g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -361,6 +378,44 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
|
|||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// injectErrorMessage returns a request payload in which the last role=user
// message carries err as a correction hint for the model.
//
// The input payload is never modified. Callers pass the task's stored request
// body and persist the task again on later retries, so editing it in place
// would accumulate hints across retries and leak them into the stored request.
// Only the containers on the path to the patched message are copied (top-level
// map, messages slice, the patched message and, for list content, the content
// slice); everything else is shared with the input.
//
// payload is returned as-is when err is nil, when payload["messages"] is not a
// non-empty []any, when there is no message whose role is the string "user",
// or when the last user message has content that is neither a string nor a
// []any.
func injectErrorMessage(payload map[string]any, err error) map[string]any {
	if err == nil {
		return payload
	}

	messages, _ := payload["messages"].([]any)
	if len(messages) == 0 {
		return payload
	}

	errMsg := fmt.Sprintf("\n\n【上一轮输出错误,请修正】%s", err.Error())

	// Walk backwards so that the most recent user message receives the hint.
	for i := len(messages) - 1; i >= 0; i-- {
		msg, ok := messages[i].(map[string]any)
		if !ok {
			continue
		}
		if role, _ := msg["role"].(string); role != "user" {
			continue
		}

		var content any
		switch c := msg["content"].(type) {
		case string:
			content = c + errMsg
		case []any:
			// Build a fresh slice: appending to c directly could write into
			// spare capacity shared with the caller's slice.
			parts := make([]any, 0, len(c)+1)
			parts = append(parts, c...)
			content = append(parts, map[string]any{
				"type": "text",
				"text": errMsg,
			})
		default:
			// Unknown content shape: nothing can be appended safely.
			return payload
		}

		patchedMsg := make(map[string]any, len(msg))
		for k, v := range msg {
			patchedMsg[k] = v
		}
		patchedMsg["content"] = content

		patchedMessages := make([]any, len(messages))
		copy(patchedMessages, messages)
		patchedMessages[i] = patchedMsg

		out := make(map[string]any, len(payload))
		for k, v := range payload {
			out[k] = v
		}
		out["messages"] = patchedMessages
		return out
	}

	return payload
}
|
||||||
|
|
||||||
// InvokeModel 调用模型服务,返回二进制结果
|
// InvokeModel 调用模型服务,返回二进制结果
|
||||||
// modelKey 用于覆盖/补充模型配置 head_msg(例如每次请求携带不同的 X-API-Key)
|
// modelKey 用于覆盖/补充模型配置 head_msg(例如每次请求携带不同的 X-API-Key)
|
||||||
func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) {
|
func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) {
|
||||||
|
|||||||
Reference in New Issue
Block a user