refactor(model): 优化模型网关的数据解析和任务处理逻辑

This commit is contained in:
2026-06-17 14:34:48 +08:00
parent b3b111995e
commit fddaf36f48
7 changed files with 231 additions and 166 deletions

View File

@@ -19,17 +19,20 @@ import (
tgjson "github.com/tidwall/gjson" tgjson "github.com/tidwall/gjson"
) )
// ParseAndValidate 解析并校验结果 // ParseAndValidate 解析模型响应,并返回标准格式
func ParseAndValidate(raw map[string]any, model *entity.ModelGatewayModel) (map[string]any, error) { func ParseAndValidate(raw map[string]any, model *entity.ModelGatewayModel) (map[string]any, error) {
// 1) 解析 content 字符串为 rounds 数组 contentStr := gconv.String(raw[entity.ResponseBody])
contentVal, ok := raw[model.ResponseBody] if strings.TrimSpace(contentStr) == "" {
if !ok { return raw, fmt.Errorf("字段 %s 为空", entity.ResponseBody)
return raw, fmt.Errorf("字段 %s 不存在", model.ResponseBody)
}
contentStr, ok := contentVal.(string)
if !ok || strings.TrimSpace(contentStr) == "" {
return raw, fmt.Errorf("字段 %s 为空或不是字符串", model.ResponseBody)
} }
contentStr = strings.Map(func(r rune) rune {
if r < 32 && r != ' ' {
return -1
}
return r
}, contentStr)
var arr []any var arr []any
if err := json.Unmarshal([]byte(contentStr), &arr); err != nil { if err := json.Unmarshal([]byte(contentStr), &arr); err != nil {
return raw, fmt.Errorf("JSON解析失败: %w", err) return raw, fmt.Errorf("JSON解析失败: %w", err)
@@ -38,17 +41,11 @@ func ParseAndValidate(raw map[string]any, model *entity.ModelGatewayModel) (map[
return raw, fmt.Errorf("解析后数组为空") return raw, fmt.Errorf("解析后数组为空")
} }
// 2) 校验必填字段 for _, field := range model.RequiredFields {
if len(model.RequiredFields) > 0 {
for i, r := range arr { for i, r := range arr {
round, ok := r.(map[string]any) round, _ := r.(map[string]any)
if !ok { if round != nil && gjson.New(round).Get(field).IsNil() {
continue return raw, fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, field)
}
for _, field := range model.RequiredFields {
if gjson.New(round).Get(field).IsNil() {
return raw, fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, field)
}
} }
} }
} }

View File

@@ -56,6 +56,7 @@ func (d *modelGatewayModelsDao) Get(ctx context.Context, req *entity.ModelGatewa
Where(entity.ModelGatewayModelCol.Id, req.Id). Where(entity.ModelGatewayModelCol.Id, req.Id).
Where(entity.ModelGatewayModelCol.Creator, req.Creator). Where(entity.ModelGatewayModelCol.Creator, req.Creator).
Where(entity.ModelGatewayModelCol.ModelName, req.ModelName). Where(entity.ModelGatewayModelCol.ModelName, req.ModelName).
Where(entity.ModelGatewayModelCol.IsChatModel, req.IsChatModel).
Fields(fields).One() Fields(fields).One()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -122,7 +123,7 @@ func (d *modelGatewayModelsDao) GetByAcrossTenant(ctx context.Context, req *enti
func (d *modelGatewayModelsDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.ModelGatewayModel, total int, err error) { func (d *modelGatewayModelsDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.ModelGatewayModel, total int, err error) {
sql := ` sql := `
SELECT DISTINCT ON (model_name) * SELECT DISTINCT ON (model_name) *
FROM asynch_models FROM ` + public.TableNameModel + `
WHERE deleted_at IS NULL WHERE deleted_at IS NULL
AND (? = '' OR model_name LIKE ?) AND (? = '' OR model_name LIKE ?)
` `

View File

@@ -7,7 +7,6 @@ import (
"model-gateway/model/entity" "model-gateway/model/entity"
"gitea.redpowerfuture.com/red-future/common/db/gfdb" "gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/util/gconv" "github.com/gogf/gf/v2/util/gconv"
) )
@@ -128,32 +127,32 @@ func (d *modelGatewayTaskDao) GetPendingAsyncTasks(ctx context.Context, limit in
// ClaimByID 按主键抢占,返回抢占后的任务 // ClaimByID 按主键抢占,返回抢占后的任务
func (d *modelGatewayTaskDao) ClaimByID(ctx context.Context, id int64) (*entity.ModelGatewayTask, error) { func (d *modelGatewayTaskDao) ClaimByID(ctx context.Context, id int64) (*entity.ModelGatewayTask, error) {
// 1) 先查任务
var task entity.ModelGatewayTask var task entity.ModelGatewayTask
err := gfdb.DB(ctx, public.DbNameModelGateway).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error { r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
r, err := tx.Model(public.TableNameTask). Where(entity.ModelGatewayTaskCol.Id, id).
Where(entity.ModelGatewayTaskCol.Id, id). Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending).
Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending). One()
Limit(1).
LockUpdate().
One()
if err != nil {
return err
}
if r.IsEmpty() {
return fmt.Errorf("任务已被抢占或不存在: id=%d", id)
}
if err := r.Struct(&task); err != nil {
return err
}
_, err = tx.Model(public.TableNameTask).
Data(&entity.ModelGatewayTask{State: public.TaskStatusRunning}).
Where(entity.ModelGatewayTaskCol.Id, id).
OmitEmpty().
Update()
return err
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
if r.IsEmpty() {
return nil, fmt.Errorf("任务已被抢占或不存在: id=%d", id)
}
if err = r.Struct(&task); err != nil {
return nil, err
}
// 2) 改为执行中
_, err = gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
Data(&entity.ModelGatewayTask{State: public.TaskStatusRunning}).
Where(entity.ModelGatewayTaskCol.Id, id).
Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending). // 防并发
OmitEmpty().
Update()
if err != nil {
return nil, err
}
return &task, nil return &task, nil
} }

View File

@@ -4,99 +4,98 @@ import "gitea.redpowerfuture.com/red-future/common/beans"
type modelGatewayModelCol struct { type modelGatewayModelCol struct {
beans.SQLBaseCol beans.SQLBaseCol
ModelName string ModelName string
ModelType string ModelType string
BaseURL string BaseURL string
HttpMethod string HttpMethod string
HeadMsg string HeadMsg string
FormJSON string FormJSON string
RequestMapping string RequestMapping string
ResponseMapping string ResponseMapping string
ResponseBody string RequiredFields string
ResponseTokenField string IsPrivate string
RequiredFields string IsChatModel string
IsPrivate string CallMode string
IsChatModel string ApiKey string
CallMode string Enabled string
ApiKey string MaxConcurrency string
Enabled string TimeoutSeconds string
MaxConcurrency string RetryTimes string
TimeoutSeconds string AutoCleanSeconds string
RetryTimes string IsOwner string
AutoCleanSeconds string OperatorName string
IsOwner string TokenConfig string
OperatorName string ExtendMapping string
TokenConfig string QueryConfig string
ExtendMapping string StreamConfig string
QueryConfig string FirstFrame string
StreamConfig string LastFrame string
FirstFrame string MaxTokens string
LastFrame string
MaxTokens string
} }
var ModelGatewayModelCol = modelGatewayModelCol{ var ModelGatewayModelCol = modelGatewayModelCol{
SQLBaseCol: beans.DefSQLBaseCol, SQLBaseCol: beans.DefSQLBaseCol,
ModelName: "model_name", ModelName: "model_name",
ModelType: "model_type", ModelType: "model_type",
BaseURL: "base_url", BaseURL: "base_url",
HttpMethod: "http_method", HttpMethod: "http_method",
HeadMsg: "head_msg", HeadMsg: "head_msg",
FormJSON: "form_json", FormJSON: "form_json",
RequestMapping: "request_mapping", RequestMapping: "request_mapping",
ResponseMapping: "response_mapping", ResponseMapping: "response_mapping",
ResponseBody: "response_body", RequiredFields: "required_fields",
ResponseTokenField: "response_token_field", IsPrivate: "is_private",
RequiredFields: "required_fields", IsChatModel: "is_chat_model",
IsPrivate: "is_private", CallMode: "call_mode",
IsChatModel: "is_chat_model", ApiKey: "api_key",
CallMode: "call_mode", Enabled: "enabled",
ApiKey: "api_key", MaxConcurrency: "max_concurrency",
Enabled: "enabled", TimeoutSeconds: "timeout_seconds",
MaxConcurrency: "max_concurrency", RetryTimes: "retry_times",
TimeoutSeconds: "timeout_seconds", AutoCleanSeconds: "auto_clean_seconds",
RetryTimes: "retry_times", IsOwner: "is_owner",
AutoCleanSeconds: "auto_clean_seconds", OperatorName: "operator_name",
IsOwner: "is_owner", TokenConfig: "token_config",
OperatorName: "operator_name", ExtendMapping: "extend_mapping",
TokenConfig: "token_config", QueryConfig: "query_config",
ExtendMapping: "extend_mapping", StreamConfig: "stream_config",
QueryConfig: "query_config", FirstFrame: "first_frame",
StreamConfig: "stream_config", LastFrame: "last_frame",
FirstFrame: "first_frame", MaxTokens: "max_tokens",
LastFrame: "last_frame",
MaxTokens: "max_tokens",
} }
type ModelGatewayModel struct { type ModelGatewayModel struct {
beans.SQLBaseDO `orm:",inline"` beans.SQLBaseDO `orm:",inline"`
ModelName string `orm:"model_name" json:"modelName"` ModelName string `orm:"model_name" json:"modelName"`
ModelType int `orm:"model_type" json:"modelType"` ModelType int `orm:"model_type" json:"modelType"`
BaseURL string `orm:"base_url" json:"baseUrl"` BaseURL string `orm:"base_url" json:"baseUrl"`
HttpMethod string `orm:"http_method" json:"httpMethod"` HttpMethod string `orm:"http_method" json:"httpMethod"`
HeadMsg map[string]any `orm:"head_msg" json:"headMsg"` HeadMsg map[string]any `orm:"head_msg" json:"headMsg"`
Form []map[string]any `orm:"form_json" json:"form"` Form []map[string]any `orm:"form_json" json:"form"`
RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"` RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"`
ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"` ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"`
ResponseBody string `orm:"response_body" json:"responseBody"` RequiredFields []string `orm:"required_fields" json:"requiredFields"`
ResponseTokenField string `orm:"response_token_field" json:"tokenField"` IsPrivate *int `orm:"is_private" json:"isPrivate"`
RequiredFields []string `orm:"required_fields" json:"requiredFields"` IsChatModel *int `orm:"is_chat_model" json:"isChatModel"`
IsPrivate *int `orm:"is_private" json:"isPrivate"` CallMode *int `orm:"call_mode" json:"callMode"`
IsChatModel *int `orm:"is_chat_model" json:"isChatModel"` ApiKey string `orm:"api_key" json:"apiKey"`
CallMode *int `orm:"call_mode" json:"callMode"` Enabled *int `orm:"enabled" json:"enabled"`
ApiKey string `orm:"api_key" json:"apiKey"` MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
Enabled *int `orm:"enabled" json:"enabled"` TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"` RetryTimes int `orm:"retry_times" json:"retryTimes"`
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"` AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
RetryTimes int `orm:"retry_times" json:"retryTimes"` IsOwner *int `orm:"is_owner" json:"isOwner"`
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"` OperatorName string `orm:"operator_name" json:"operatorName"`
IsOwner *int `orm:"is_owner" json:"isOwner"` TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"`
OperatorName string `orm:"operator_name" json:"operatorName"` ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"`
TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"` QueryConfig map[string]any `orm:"query_config" json:"queryConfig"`
ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"` StreamConfig map[string]any `orm:"stream_config" json:"streamConfig"`
QueryConfig map[string]any `orm:"query_config" json:"queryConfig"` FirstFrame string `orm:"first_frame" json:"firstFrame"`
StreamConfig map[string]any `orm:"stream_config" json:"streamConfig"` LastFrame string `orm:"last_frame" json:"lastFrame"`
FirstFrame string `orm:"first_frame" json:"firstFrame"` MaxTokens int `orm:"max_tokens" json:"maxTokens"`
LastFrame string `orm:"last_frame" json:"lastFrame"`
MaxTokens int `orm:"max_tokens" json:"maxTokens"`
} }
const (
ResponseBody = "content" //返回主体(必填)
TotalTokens = "total_tokens" //总token数
)

View File

@@ -99,7 +99,7 @@ func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetM
ModelName: req.ModelName, ModelName: req.ModelName,
IsChatModel: req.IsChatModel, IsChatModel: req.IsChatModel,
}) })
if err != nil { if err != nil || model == nil {
return nil, err return nil, err
} }
return &dto.GetModelRes{ return &dto.GetModelRes{

View File

@@ -107,13 +107,27 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
}, },
}) })
// 5) 获取任务信息 // 5) 抢占任务:改为执行中
task, err := dao.ModelGatewayTask.ClaimByID(ctx, id) rows, err := dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
SQLBaseDO: beans.SQLBaseDO{Id: id},
State: public.TaskStatusRunning,
})
if err != nil {
return nil, err
}
if rows == 0 {
return nil, fmt.Errorf("任务不存在: id=%d", id)
}
// 6) 查询任务信息
task, err := dao.ModelGatewayTask.Get(ctx, &entity.ModelGatewayTask{
SQLBaseDO: beans.SQLBaseDO{Id: id},
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 5) 创建成功后立即异步尝试执行当前任务 // 7) 创建成功后立即异步尝试执行当前任务
go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req) go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req)
return &dto.CreateTaskRes{TaskID: taskID}, nil return &dto.CreateTaskRes{TaskID: taskID}, nil

View File

@@ -67,25 +67,39 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa
// ============================================ // ============================================
// 2) 调用模型 // 2) 调用模型
// ============================================ // ============================================
switch { for attempt := 0; ; attempt++ {
case model.CallMode != nil && *model.CallMode == public.CallModeStream: if attempt > 0 {
rawBytes, streamErr := w.callModelStream(ctx, task, model, body) g.Log().Infof(ctx, "[执行任务][重试] 调用模型 第%d次 taskId=%s", attempt, task.TaskID)
if streamErr != nil { time.Sleep(time.Duration(attempt) * time.Second)
w.failTask(ctx, task, startTime, streamErr.Error()) }
switch {
case model.CallMode != nil && *model.CallMode == public.CallModeStream:
rawBytes, streamErr := w.callModelStream(ctx, task, model, body)
if streamErr != nil {
err = streamErr
g.Log().Warningf(ctx, "[执行任务][调用失败] taskId=%s attempt=%d err=%v", task.TaskID, attempt, err)
continue
}
result, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
case model.CallMode != nil && *model.CallMode == public.CallModeAsync:
result, err = w.callModel(ctx, task, model, body)
if err == nil {
result, err = util.PullTaskResult(ctx, result, model.QueryConfig, model.HeadMsg)
}
default:
result, err = w.callModel(ctx, task, model, body)
}
if err == nil {
break
}
if !strings.Contains(err.Error(), "Timeout") {
w.failTask(ctx, task, startTime, err.Error())
return return
} }
result, err = util.ParseStreamResponse(rawBytes, model.StreamConfig) g.Log().Warningf(ctx, "[执行任务][调用失败] taskId=%s attempt=%d err=%v", task.TaskID, attempt, err)
case model.CallMode != nil && *model.CallMode == public.CallModeAsync:
result, err = w.callModel(ctx, task, model, body)
if err == nil {
result, err = util.PullTaskResult(ctx, result, model.QueryConfig, model.HeadMsg)
}
default:
result, err = w.callModel(ctx, task, model, body)
}
if err != nil {
w.failTask(ctx, task, startTime, err.Error())
return
} }
// ============================================ // ============================================
@@ -205,7 +219,7 @@ func (w *asyncWorker) callModelAsync(ctx context.Context, task *entity.ModelGate
return nil, err return nil, err
} }
// 2. 拿到 task_id // 2. 拿到 task_id
taskID := gjson.New(body).Get(model.ResponseBody).String() taskID := gjson.New(body).Get(entity.ResponseBody).String()
// 3. 创建等待通道 // 3. 创建等待通道
ch := make(chan asyncResult, 1) ch := make(chan asyncResult, 1)
@@ -294,6 +308,8 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.ModelGatewayTa
// parseAndRetry 解析模型返回结果,并重试 // parseAndRetry 解析模型返回结果,并重试
func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) { func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) {
var lastErr error
for attempt := 0; attempt <= maxRetry; attempt++ { for attempt := 0; attempt <= maxRetry; attempt++ {
if attempt > 0 { if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID) g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
@@ -302,6 +318,7 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
// 1) 响应映射 // 1) 响应映射
mapped, err := util.MapResponsePayload(model.ResponseMapping, body) mapped, err := util.MapResponsePayload(model.ResponseMapping, body)
if err != nil { if err != nil {
lastErr = err
g.Log().Warningf(ctx, "[执行任务][映射失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err) g.Log().Warningf(ctx, "[执行任务][映射失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
if attempt == maxRetry { if attempt == maxRetry {
return nil, fmt.Errorf("响应映射重试耗尽: %w", err) return nil, fmt.Errorf("响应映射重试耗尽: %w", err)
@@ -309,10 +326,10 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
continue continue
} }
// 2) 存 token 到数据库,防止后续失败丢失 // 2) 存 token
if _, ok := mapped[model.ResponseTokenField]; ok { if _, ok := mapped[entity.TotalTokens]; ok {
task.ExpendTokens = gconv.Int64(mapped[model.ResponseTokenField]) task.ExpendTokens = gconv.Int64(mapped[entity.TotalTokens])
_, err = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{ _, _ = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
SQLBaseDO: beans.SQLBaseDO{Id: task.Id}, SQLBaseDO: beans.SQLBaseDO{Id: task.Id},
ExpendTokens: task.ExpendTokens, ExpendTokens: task.ExpendTokens,
}) })
@@ -326,9 +343,9 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
if err == nil { if err == nil {
return parsed, nil return parsed, nil
} }
lastErr = err
case public.BuildTypeStruct: case public.BuildTypeStruct:
parsed = util.ParseStructResult(mapped, model.ResponseBody) return util.ParseStructResult(mapped, entity.ResponseBody), nil
return parsed, nil
default: default:
return mapped, nil return mapped, nil
} }
@@ -336,22 +353,22 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err) g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
if attempt == maxRetry { if attempt == maxRetry {
return nil, fmt.Errorf("JSON解析重试耗尽: %w", err) return nil, fmt.Errorf("JSON解析重试耗尽: %w", lastErr)
} }
// 4) 重新调模型(直接调,不走缓存) // 4) 拼接错误信息到请求体,重调模型
task.RetryCount++ task.RetryCount++
_, _ = dao.ModelGatewayTask.Update(ctx, task) _, _ = dao.ModelGatewayTask.Update(ctx, task)
rawData, callErr := InvokeModel(ctx, model, task.RequestPayload.Body)
body = injectErrorMessage(task.RequestPayload.Body, lastErr)
rawData, callErr := InvokeModel(ctx, model, body)
if callErr != nil { if callErr != nil {
g.Log().Warningf(ctx, "[执行任务][重调模型失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, callErr) g.Log().Warningf(ctx, "[执行任务][重调模型失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, callErr)
continue continue
} }
// 5) 解析原始响应,覆盖 body 进入下一轮
var rawResp map[string]any var rawResp map[string]any
if err = json.Unmarshal(rawData, &rawResp); err != nil { if err := json.Unmarshal(rawData, &rawResp); err != nil {
g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err) g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err)
continue continue
} }
@@ -361,6 +378,44 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
return body, nil return body, nil
} }
// injectErrorMessage 将错误信息拼接到 user 消息中
func injectErrorMessage(payload map[string]any, err error) map[string]any {
if err == nil {
return payload
}
messages, _ := payload["messages"].([]any)
if len(messages) == 0 {
return payload
}
errMsg := fmt.Sprintf("\n\n【上一轮输出错误请修正】%s", err.Error())
// 找到最后一个 role=user 的消息,追加错误提示
for i := len(messages) - 1; i >= 0; i-- {
msg, ok := messages[i].(map[string]any)
if !ok {
continue
}
if gconv.String(msg["role"]) != "user" {
continue
}
switch c := msg["content"].(type) {
case string:
msg["content"] = c + errMsg
case []any:
msg["content"] = append(c, map[string]any{
"type": "text",
"text": errMsg,
})
}
break
}
return payload
}
// InvokeModel 调用模型服务,返回二进制结果 // InvokeModel 调用模型服务,返回二进制结果
// modelKey 用于覆盖/补充模型配置 head_msg例如每次请求携带不同的 X-API-Key // modelKey 用于覆盖/补充模型配置 head_msg例如每次请求携带不同的 X-API-Key
func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) { func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) {