feat(model): 添加流式配置支持并优化响应处理

This commit is contained in:
2026-05-30 22:08:46 +08:00
parent 558fd49ec1
commit c7e9eb889b
5 changed files with 288 additions and 56 deletions

View File

@@ -124,14 +124,61 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *
return
}
// 4) 调用模型(不重试,失败直接回调)
textResult, err := w.callModel(ctx, t, model, payload)
// 4) 调用模型
var textResult map[string]any
if streamEnabled, _ := model.StreamConfig["enabled"].(bool); streamEnabled {
rawBytes, modelErr := w.callModelRaw(ctx, t, model, payload)
if modelErr != nil {
w.failTask(ctx, t, modelErr.Error())
return
}
textResult, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
if err != nil {
w.failTask(ctx, t, err.Error())
return
}
} else {
textResult, err = w.callModel(ctx, t, model, payload)
if err != nil {
w.failTask(ctx, t, err.Error())
return
}
}
// 5) 模型返回映射处理
textResult, err = util.MapResponsePayload(model.ResponseMapping, textResult)
if err != nil {
w.failTask(ctx, t, err.Error())
return
}
// 5) 上传 OSS可重试
// 6) 保存临时文件区分二进制音频和JSON文本
if audioData, ok := textResult["audio"].([]byte); ok {
tmpPath, tmpErr := saveTmpResult(t.TaskID, audioData, ".mp3")
if tmpErr == nil && tmpPath != "" {
if t.TmpFile != "" {
_ = os.Remove(t.TmpFile)
}
t.TmpFile = tmpPath
t.Phase = 1
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, t.Id, tmpPath)
}
} else {
mappedBytes, _ := json.Marshal(textResult)
if len(mappedBytes) > 0 {
tmpPath, tmpErr := saveTmpResult(t.TaskID, mappedBytes, ".json")
if tmpErr == nil && tmpPath != "" {
if t.TmpFile != "" {
_ = os.Remove(t.TmpFile)
}
t.TmpFile = tmpPath
t.Phase = 1
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, t.Id, tmpPath)
}
}
}
// 7) 上传 OSS可重试
var oss *gateway.UploadFileResponse
for attempt := 0; attempt <= maxRetry; attempt++ {
if attempt > 0 {
@@ -150,35 +197,35 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *
}
}
// 6) 解析校验(可重试,失败重新调模型)
//if req.BuildType == 1 {
// for attempt := 0; attempt <= maxRetry; attempt++ {
// if attempt > 0 {
// g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID)
// }
// // 6.1) 校验数据
// err = util.ValidatePromptResult(textResult, model)
// if err == nil {
// break
// }
// g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v",
// t.TaskID, attempt, maxRetry, err)
// if attempt == maxRetry {
// w.failTask(ctx, t, fmt.Sprintf("JSON解析重试耗尽: %v", err))
// return
// }
// // 6.2) 重新调模型
// newResult, modelErr := w.callModel(ctx, t, model, payload)
// if modelErr != nil {
// g.Log().Warningf(ctx, "[执行任务][重试] 重新调模型失败 taskId=%s attempt=%d/%d err=%v",
// t.TaskID, attempt, maxRetry, modelErr)
// continue
// }
// textResult = newResult
// }
//}
//8) 解析校验(可重试,失败重新调模型)
if req.BuildType == 1 {
for attempt := 0; attempt <= maxRetry; attempt++ {
if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID)
}
// 6.1) 校验数据
err = util.ValidatePromptResult(textResult, model)
if err == nil {
break
}
g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v",
t.TaskID, attempt, maxRetry, err)
if attempt == maxRetry {
w.failTask(ctx, t, fmt.Sprintf("JSON解析重试耗尽: %v", err))
return
}
// 6.2) 重新调模型
newResult, modelErr := w.callModel(ctx, t, model, payload)
if modelErr != nil {
g.Log().Warningf(ctx, "[执行任务][重试] 重新调模型失败 taskId=%s attempt=%d/%d err=%v",
t.TaskID, attempt, maxRetry, modelErr)
continue
}
textResult = newResult
}
}
// 7) 成功回调
// 9) 成功回调
t.State = 2
t.OssFile = oss.FileAddressPrefix + oss.FileURL
t.FileType = oss.FileFormat
@@ -199,9 +246,40 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *
g.Log().Infof(ctx, "[执行任务][成功] taskId=%s fileType=%s textLen=%d callbackUrl=%s",
t.TaskID, oss.FileFormat, len(textResult), t.CallbackURL)
// 10) 删除临时文件
_ = os.Remove(t.TmpFile)
}
// callModelRaw 调用模型,返回原始字节(不做响应映射,用于流式输出)
func (w *asyncWorker) callModelRaw(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, payload map[string]any) ([]byte, error) {
var data []byte
var err error
if task.Phase == 1 && strings.TrimSpace(task.TmpFile) != "" {
data, err = os.ReadFile(task.TmpFile)
if err != nil || len(data) == 0 {
data = nil
}
}
if data == nil {
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName)
data, err = InvokeModel(ctx, model, payload, task.ModelKey)
if err != nil {
return nil, err
}
tmpPath, tmpErr := saveTmpResult(task.TaskID, data, "")
if tmpErr == nil && tmpPath != "" {
task.TmpFile = tmpPath
task.Phase = 1
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath)
}
}
return data, nil
}
// 返回: ossURL(成功时有值), fileType, textResult(失败时是错误信息), retryable(是否可重试)
// callModel 调用模型 + 检测文件类型 + 保存临时文件
func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, payload map[string]any) (map[string]any, error) {
@@ -302,14 +380,7 @@ func InvokeModel(ctx context.Context, model *entity.AsynchModel, payload map[str
msg := string(b)
return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
}
// 8响应参数映射
mappedResponse, err := util.MapResponsePayload(model.ResponseMapping, b)
if err != nil {
g.Log().Warningf(ctx, "响应参数映射失败: %v返回原始数据", err)
return b, nil
}
return mappedResponse, nil
return b, nil
}
// // InvokeModel 调用模型服务,返回二进制结果