From c7e9eb889b90f8be5de8aded63f9f255fc1d3075 Mon Sep 17 00:00:00 2001 From: WangLiZhao <1838393649@qq.com> Date: Sat, 30 May 2026 22:08:46 +0800 Subject: [PATCH] =?UTF-8?q?feat(model):=20=E6=B7=BB=E5=8A=A0=E6=B5=81?= =?UTF-8?q?=E5=BC=8F=E9=85=8D=E7=BD=AE=E6=94=AF=E6=8C=81=E5=B9=B6=E4=BC=98?= =?UTF-8?q?=E5=8C=96=E5=93=8D=E5=BA=94=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/util/mapping.go | 31 ++++++-- common/util/streaming.go | 150 +++++++++++++++++++++++++++++++++++ model/entity/asynch_model.go | 3 + service/task/worker.go | 149 +++++++++++++++++++++++++--------- update.sql | 11 +-- 5 files changed, 288 insertions(+), 56 deletions(-) create mode 100644 common/util/streaming.go diff --git a/common/util/mapping.go b/common/util/mapping.go index 01b0cbe..f7c344b 100644 --- a/common/util/mapping.go +++ b/common/util/mapping.go @@ -1,6 +1,7 @@ package util import ( + "encoding/json" "fmt" "model-gateway/model/entity" "net/url" @@ -9,6 +10,7 @@ import ( "github.com/gogf/gf/v2/encoding/gjson" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/util/gconv" + tgjson "github.com/tidwall/gjson" ) // ValidatePromptResult 校验模型返回结果的 JSON 结构完整性 @@ -67,27 +69,40 @@ func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any { } // MapResponsePayload 映射模型响应为标准格式 -func MapResponsePayload(mapping map[string]any, responseBytes []byte) ([]byte, error) { +func MapResponsePayload(mapping map[string]any, result map[string]any) (map[string]any, error) { if len(mapping) == 0 { - return responseBytes, nil + return result, nil } - responseJson := gjson.New(responseBytes) - resultJson := gjson.New("{}") + // 把 result 转成 JSON 字符串,tidwall/gjson 需要字符串输入 + resultBytes, _ := json.Marshal(result) + resultStr := string(resultBytes) + + mapped := make(map[string]any) for standardField, modelPath := range mapping { path := gconv.String(modelPath) if path == "" { continue } - val := responseJson.Get(path) - if val.IsNil() { + + value := tgjson.Get(resultStr, path) + if !value.Exists() { continue } - resultJson.Set(standardField, val.Val()) + // 如果是数组路径(含 #),取 Array;否则取单值 + if strings.Contains(path, "#") { + var arr []any + for _, v := range value.Array() { + arr = append(arr, v.Value()) + } + mapped[standardField] = arr + } else { + mapped[standardField] = value.Value() + } } - return []byte(resultJson.String()), nil + return mapped, nil } // ParseHeadMsgHeaders 支持多个 header 绑定,逗号分隔: diff --git a/common/util/streaming.go b/common/util/streaming.go new file mode 100644 index 0000000..7cb3707 --- /dev/null +++ b/common/util/streaming.go @@ -0,0 +1,150 @@ +package util + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "sort" + "strings" + + "github.com/gogf/gf/v2/encoding/gjson" +) + +// ================================================================ + +// ParseStreamResponse 流式响应解析(通用入口) +func ParseStreamResponse(rawBytes []byte, streamConfig map[string]any) (map[string]any, error) { + enabled, _ := streamConfig["enabled"].(bool) + if !enabled { + return gjson.New(string(rawBytes)).Map(), nil + } + + parser, _ := streamConfig["parser"].(string) + if parser == "base64_concat" { + return parseBase64Stream(rawBytes) + } + + return parseSSEStream(rawBytes, streamConfig) +} + +// parseBase64Stream 拼接流式 base64 并解码为二进制(TTS 等音频模型) +func parseBase64Stream(rawBytes []byte) (map[string]any, error) { + lines := strings.Split(string(rawBytes), "\n") + var audioBase64 strings.Builder + + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { + continue + } + + var chunk map[string]any + if err := json.Unmarshal([]byte(line), &chunk); err != nil { + continue + } + + if data, ok := chunk["data"].(string); ok && data != "" { + audioBase64.WriteString(data) + } + } + + cleanBase64 := strings.Map(func(r rune) rune { + if r == ' ' || r == '\n' || r == '\r' || r == '\t' { + return -1 + } + return r + }, audioBase64.String()) + + audioBytes, err := base64.StdEncoding.DecodeString(cleanBase64) + if err != nil { + audioBytes, err = base64.RawStdEncoding.DecodeString(cleanBase64) + if err != nil { + return nil, fmt.Errorf("base64 解码失败: %w", err) + } + } + + return map[string]any{"audio": audioBytes}, nil +} + +// parseSSEStream SSE 流式解析(图片模型等) +func parseSSEStream(rawBytes []byte, streamConfig map[string]any) (map[string]any, error) { + events, _ := streamConfig["events"].([]any) + if len(events) == 0 { + return gjson.New(string(rawBytes)).Map(), nil + } + + lines := strings.Split(string(rawBytes), "\n") + result := make(map[string]any) + var partials []map[string]any + + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" || line == "[DONE]" { + continue + } + if strings.HasPrefix(line, "event:") { + continue + } + if strings.HasPrefix(line, "data:") { + line = strings.TrimPrefix(line, "data:") + line = strings.TrimSpace(line) + } + + var chunk map[string]any + if err := json.Unmarshal([]byte(line), &chunk); err != nil { + continue + } + + chunkType, _ := chunk["type"].(string) + + for _, evt := range events { + e, _ := evt.(map[string]any) + match, _ := e["match"].(string) + if !strings.Contains(chunkType, match) { + continue + } + + fields, _ := e["fields"].(map[string]any) + aggregateTo, _ := e["aggregate_to"].(string) + evtType, _ := e["type"].(string) + + switch evtType { + case "partial": + item := make(map[string]any) + for localKey, chunkKey := range fields { + item[localKey] = chunk[chunkKey.(string)] + } + partials = append(partials, item) + + case "final": + for localKey, chunkKey := range fields { + val := gjson.New(chunk).Get(chunkKey.(string)) + if !val.IsNil() { + if _, exists := result[aggregateTo]; !exists { + result[aggregateTo] = make(map[string]any) + } + result[aggregateTo].(map[string]any)[localKey] = val.Val() + } + } + } + } + } + + if len(partials) > 0 { + for _, evt := range events { + e, _ := evt.(map[string]any) + if e["type"] == "partial" { + if orderBy, ok := e["order_by"].(string); ok { + sort.Slice(partials, func(i, j int) bool { + return fmt.Sprint(partials[i][orderBy]) < fmt.Sprint(partials[j][orderBy]) + }) + } + result[e["aggregate_to"].(string)] = partials + break + } + } + } + + mergedBytes, _ := json.Marshal(result) + return gjson.New(mergedBytes).Map(), nil +} diff --git a/model/entity/asynch_model.go b/model/entity/asynch_model.go index 7abc09c..505781d 100644 --- a/model/entity/asynch_model.go +++ b/model/entity/asynch_model.go @@ -32,6 +32,7 @@ type asynchModelCol struct { TokenConfig string ExtendMapping string QueryConfig string + StreamConfig string } var AsynchModelCol = asynchModelCol{ @@ -64,6 +65,7 @@ var AsynchModelCol = asynchModelCol{ TokenConfig: "token_config", ExtendMapping: "extend_mapping", QueryConfig: "query_config", + StreamConfig: "stream_config", } // AsynchModel 异步模型配置 @@ -97,4 +99,5 @@ type AsynchModel struct { TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"` ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"` QueryConfig map[string]any `orm:"query_config" json:"queryConfig"` + StreamConfig map[string]any `orm:"stream_config" json:"streamConfig"` } diff --git a/service/task/worker.go b/service/task/worker.go index 1a2b03e..0edeb98 100644 --- a/service/task/worker.go +++ b/service/task/worker.go @@ -124,14 +124,61 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req * return } - // 4) 调用模型(不重试,失败直接回调) - textResult, err := w.callModel(ctx, t, model, payload) + // 4) 调用模型 + var textResult map[string]any + if streamEnabled, _ := model.StreamConfig["enabled"].(bool); streamEnabled { + rawBytes, modelErr := w.callModelRaw(ctx, t, model, payload) + if modelErr != nil { + w.failTask(ctx, t, modelErr.Error()) + return + } + textResult, err = util.ParseStreamResponse(rawBytes, model.StreamConfig) + if err != nil { + w.failTask(ctx, t, err.Error()) + return + } + } else { + textResult, err = w.callModel(ctx, t, model, payload) + if err != nil { + w.failTask(ctx, t, err.Error()) + return + } + } + + // 5) 模型返回映射处理 + textResult, err = util.MapResponsePayload(model.ResponseMapping, textResult) if err != nil { w.failTask(ctx, t, err.Error()) return } - // 5) 上传 OSS(可重试) + // 6) 保存临时文件(区分二进制音频和JSON文本) + if audioData, ok := textResult["audio"].([]byte); ok { + tmpPath, tmpErr := saveTmpResult(t.TaskID, audioData, ".mp3") + if tmpErr == nil && tmpPath != "" { + if t.TmpFile != "" { + _ = os.Remove(t.TmpFile) + } + t.TmpFile = tmpPath + t.Phase = 1 + _ = dao.Task.UpdateTmpAfterModelGlobal(ctx, t.Id, tmpPath) + } + } else { + mappedBytes, _ := json.Marshal(textResult) + if len(mappedBytes) > 0 { + tmpPath, tmpErr := saveTmpResult(t.TaskID, mappedBytes, ".json") + if tmpErr == nil && tmpPath != "" { + if t.TmpFile != "" { + _ = os.Remove(t.TmpFile) + } + t.TmpFile = tmpPath + t.Phase = 1 + _ = dao.Task.UpdateTmpAfterModelGlobal(ctx, t.Id, tmpPath) + } + } + } + + // 7) 上传 OSS(可重试) var oss *gateway.UploadFileResponse for attempt := 0; attempt <= maxRetry; attempt++ { if attempt > 0 { @@ -150,35 +197,35 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req * } } - // 6) 解析校验(可重试,失败重新调模型) - //if req.BuildType == 1 { - // for attempt := 0; attempt <= maxRetry; attempt++ { - // if attempt > 0 { - // g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID) - // } - // // 6.1) 校验数据 - // err = util.ValidatePromptResult(textResult, model) - // if err == nil { - // break - // } - // g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v", - // t.TaskID, attempt, maxRetry, err) - // if attempt == maxRetry { - // w.failTask(ctx, t, fmt.Sprintf("JSON解析重试耗尽: %v", err)) - // return - // } - // // 6.2) 重新调模型 - // newResult, modelErr := w.callModel(ctx, t, model, payload) - // if modelErr != nil { - // g.Log().Warningf(ctx, "[执行任务][重试] 重新调模型失败 taskId=%s attempt=%d/%d err=%v", - // t.TaskID, attempt, maxRetry, modelErr) - // continue - // } - // textResult = newResult - // } - //} + //8) 解析校验(可重试,失败重新调模型) + if req.BuildType == 1 { + for attempt := 0; attempt <= maxRetry; attempt++ { + if attempt > 0 { + g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID) + } + // 6.1) 校验数据 + err = util.ValidatePromptResult(textResult, model) + if err == nil { + break + } + g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v", + t.TaskID, attempt, maxRetry, err) + if attempt == maxRetry { + w.failTask(ctx, t, fmt.Sprintf("JSON解析重试耗尽: %v", err)) + return + } + // 6.2) 重新调模型 + newResult, modelErr := w.callModel(ctx, t, model, payload) + if modelErr != nil { + g.Log().Warningf(ctx, "[执行任务][重试] 重新调模型失败 taskId=%s attempt=%d/%d err=%v", + t.TaskID, attempt, maxRetry, modelErr) + continue + } + textResult = newResult + } + } - // 7) 成功回调 + // 9) 成功回调 t.State = 2 t.OssFile = oss.FileAddressPrefix + oss.FileURL t.FileType = oss.FileFormat @@ -199,9 +246,40 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req * g.Log().Infof(ctx, "[执行任务][成功] taskId=%s fileType=%s textLen=%d callbackUrl=%s", t.TaskID, oss.FileFormat, len(textResult), t.CallbackURL) + + // 10) 删除临时文件 _ = os.Remove(t.TmpFile) } +// callModelRaw 调用模型,返回原始字节(不做响应映射,用于流式输出) +func (w *asyncWorker) callModelRaw(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, payload map[string]any) ([]byte, error) { + var data []byte + var err error + + if task.Phase == 1 && strings.TrimSpace(task.TmpFile) != "" { + data, err = os.ReadFile(task.TmpFile) + if err != nil || len(data) == 0 { + data = nil + } + } + + if data == nil { + _ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName) + data, err = InvokeModel(ctx, model, payload, task.ModelKey) + if err != nil { + return nil, err + } + tmpPath, tmpErr := saveTmpResult(task.TaskID, data, "") + if tmpErr == nil && tmpPath != "" { + task.TmpFile = tmpPath + task.Phase = 1 + _ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath) + } + } + + return data, nil +} + // 返回: ossURL(成功时有值), fileType, textResult(失败时是错误信息), retryable(是否可重试) // callModel 调用模型 + 检测文件类型 + 保存临时文件 func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, payload map[string]any) (map[string]any, error) { @@ -302,14 +380,7 @@ func InvokeModel(ctx context.Context, model *entity.AsynchModel, payload map[str msg := string(b) return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg) } - - // 8)响应参数映射 - mappedResponse, err := util.MapResponsePayload(model.ResponseMapping, b) - if err != nil { - g.Log().Warningf(ctx, "响应参数映射失败: %v,返回原始数据", err) - return b, nil - } - return mappedResponse, nil + return b, nil } // // InvokeModel 调用模型服务,返回二进制结果 diff --git a/update.sql b/update.sql index d83b2de..672a85d 100644 --- a/update.sql +++ b/update.sql @@ -42,15 +42,8 @@ CREATE TABLE IF NOT EXISTS asynch_models ( remark TEXT DEFAULT '' -- 备注 response_token_field VARCHAR(128) NOT NULL DEFAULT ''; -- 响应中消耗token的字段映射 operator_name VARCHAR(64) NOT NULL DEFAULT '', -- 运营商名称 - token_config JSONB NOT NULL DEFAULT '{ - "zh_ratio": 1.0, - "en_ratio": 1.3, - "space_ratio": 0.1, - "punctuation_ratio": 0.1, - "max_window_size": 8192, - "reserve_ratio": 0.2, - "min_reserve": 512, -}'::jsonb -- Token配置 + stream_config JSONB NOT NULL DEFAULT '{}'::jsonb, -- 流式配置 + token_config JSONB NOT NULL DEFAULT '{}'::jsonb -- Token配置 extend_mapping JSONB NOT NULL DEFAULT '{}'::jsonb, query_config JSONB NOT NULL DEFAULT '{}'::jsonb; );