feat(model): 添加流式配置支持并优化响应处理

This commit is contained in:
2026-05-30 22:08:46 +08:00
parent 558fd49ec1
commit c7e9eb889b
5 changed files with 288 additions and 56 deletions

View File

@@ -1,6 +1,7 @@
package util
import (
"encoding/json"
"fmt"
"model-gateway/model/entity"
"net/url"
@@ -9,6 +10,7 @@ import (
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
tgjson "github.com/tidwall/gjson"
)
// ValidatePromptResult 校验模型返回结果的 JSON 结构完整性
@@ -67,27 +69,40 @@ func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
}
// MapResponsePayload 映射模型响应为标准格式
func MapResponsePayload(mapping map[string]any, responseBytes []byte) ([]byte, error) {
func MapResponsePayload(mapping map[string]any, result map[string]any) (map[string]any, error) {
if len(mapping) == 0 {
return responseBytes, nil
return result, nil
}
responseJson := gjson.New(responseBytes)
resultJson := gjson.New("{}")
// 把 result 转成 JSON 字符串tidwall/gjson 需要字符串输入
resultBytes, _ := json.Marshal(result)
resultStr := string(resultBytes)
mapped := make(map[string]any)
for standardField, modelPath := range mapping {
path := gconv.String(modelPath)
if path == "" {
continue
}
val := responseJson.Get(path)
if val.IsNil() {
value := tgjson.Get(resultStr, path)
if !value.Exists() {
continue
}
resultJson.Set(standardField, val.Val())
// 如果是数组路径(含 #),取 Array否则取单值
if strings.Contains(path, "#") {
var arr []any
for _, v := range value.Array() {
arr = append(arr, v.Value())
}
mapped[standardField] = arr
} else {
mapped[standardField] = value.Value()
}
}
return []byte(resultJson.String()), nil
return mapped, nil
}
// ParseHeadMsgHeaders 支持多个 header 绑定,逗号分隔:

150
common/util/streaming.go Normal file
View File

@@ -0,0 +1,150 @@
package util
import (
"encoding/base64"
"encoding/json"
"fmt"
"sort"
"strings"
"github.com/gogf/gf/v2/encoding/gjson"
)
// ================================================================
// ParseStreamResponse 流式响应解析(通用入口)
func ParseStreamResponse(rawBytes []byte, streamConfig map[string]any) (map[string]any, error) {
enabled, _ := streamConfig["enabled"].(bool)
if !enabled {
return gjson.New(string(rawBytes)).Map(), nil
}
parser, _ := streamConfig["parser"].(string)
if parser == "base64_concat" {
return parseBase64Stream(rawBytes)
}
return parseSSEStream(rawBytes, streamConfig)
}
// parseBase64Stream 拼接流式 base64 并解码为二进制TTS 等音频模型)
func parseBase64Stream(rawBytes []byte) (map[string]any, error) {
lines := strings.Split(string(rawBytes), "\n")
var audioBase64 strings.Builder
for _, line := range lines {
line = strings.TrimSpace(line)
if line == "" {
continue
}
var chunk map[string]any
if err := json.Unmarshal([]byte(line), &chunk); err != nil {
continue
}
if data, ok := chunk["data"].(string); ok && data != "" {
audioBase64.WriteString(data)
}
}
cleanBase64 := strings.Map(func(r rune) rune {
if r == ' ' || r == '\n' || r == '\r' || r == '\t' {
return -1
}
return r
}, audioBase64.String())
audioBytes, err := base64.StdEncoding.DecodeString(cleanBase64)
if err != nil {
audioBytes, err = base64.RawStdEncoding.DecodeString(cleanBase64)
if err != nil {
return nil, fmt.Errorf("base64 解码失败: %w", err)
}
}
return map[string]any{"audio": audioBytes}, nil
}
// parseSSEStream SSE 流式解析(图片模型等)
func parseSSEStream(rawBytes []byte, streamConfig map[string]any) (map[string]any, error) {
events, _ := streamConfig["events"].([]any)
if len(events) == 0 {
return gjson.New(string(rawBytes)).Map(), nil
}
lines := strings.Split(string(rawBytes), "\n")
result := make(map[string]any)
var partials []map[string]any
for _, line := range lines {
line = strings.TrimSpace(line)
if line == "" || line == "[DONE]" {
continue
}
if strings.HasPrefix(line, "event:") {
continue
}
if strings.HasPrefix(line, "data:") {
line = strings.TrimPrefix(line, "data:")
line = strings.TrimSpace(line)
}
var chunk map[string]any
if err := json.Unmarshal([]byte(line), &chunk); err != nil {
continue
}
chunkType, _ := chunk["type"].(string)
for _, evt := range events {
e, _ := evt.(map[string]any)
match, _ := e["match"].(string)
if !strings.Contains(chunkType, match) {
continue
}
fields, _ := e["fields"].(map[string]any)
aggregateTo, _ := e["aggregate_to"].(string)
evtType, _ := e["type"].(string)
switch evtType {
case "partial":
item := make(map[string]any)
for localKey, chunkKey := range fields {
item[localKey] = chunk[chunkKey.(string)]
}
partials = append(partials, item)
case "final":
for localKey, chunkKey := range fields {
val := gjson.New(chunk).Get(chunkKey.(string))
if !val.IsNil() {
if _, exists := result[aggregateTo]; !exists {
result[aggregateTo] = make(map[string]any)
}
result[aggregateTo].(map[string]any)[localKey] = val.Val()
}
}
}
}
}
if len(partials) > 0 {
for _, evt := range events {
e, _ := evt.(map[string]any)
if e["type"] == "partial" {
if orderBy, ok := e["order_by"].(string); ok {
sort.Slice(partials, func(i, j int) bool {
return fmt.Sprint(partials[i][orderBy]) < fmt.Sprint(partials[j][orderBy])
})
}
result[e["aggregate_to"].(string)] = partials
break
}
}
}
mergedBytes, _ := json.Marshal(result)
return gjson.New(mergedBytes).Map(), nil
}

View File

@@ -32,6 +32,7 @@ type asynchModelCol struct {
TokenConfig string
ExtendMapping string
QueryConfig string
StreamConfig string
}
var AsynchModelCol = asynchModelCol{
@@ -64,6 +65,7 @@ var AsynchModelCol = asynchModelCol{
TokenConfig: "token_config",
ExtendMapping: "extend_mapping",
QueryConfig: "query_config",
StreamConfig: "stream_config",
}
// AsynchModel 异步模型配置
@@ -97,4 +99,5 @@ type AsynchModel struct {
TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"`
ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"`
QueryConfig map[string]any `orm:"query_config" json:"queryConfig"`
StreamConfig map[string]any `orm:"stream_config" json:"streamConfig"`
}

View File

@@ -124,14 +124,61 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *
return
}
// 4) 调用模型(不重试,失败直接回调)
textResult, err := w.callModel(ctx, t, model, payload)
// 4) 调用模型
var textResult map[string]any
if streamEnabled, _ := model.StreamConfig["enabled"].(bool); streamEnabled {
rawBytes, modelErr := w.callModelRaw(ctx, t, model, payload)
if modelErr != nil {
w.failTask(ctx, t, modelErr.Error())
return
}
textResult, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
if err != nil {
w.failTask(ctx, t, err.Error())
return
}
} else {
textResult, err = w.callModel(ctx, t, model, payload)
if err != nil {
w.failTask(ctx, t, err.Error())
return
}
}
// 5) 模型返回映射处理
textResult, err = util.MapResponsePayload(model.ResponseMapping, textResult)
if err != nil {
w.failTask(ctx, t, err.Error())
return
}
// 5) 上传 OSS可重试
// 6) 保存临时文件区分二进制音频和JSON文本
if audioData, ok := textResult["audio"].([]byte); ok {
tmpPath, tmpErr := saveTmpResult(t.TaskID, audioData, ".mp3")
if tmpErr == nil && tmpPath != "" {
if t.TmpFile != "" {
_ = os.Remove(t.TmpFile)
}
t.TmpFile = tmpPath
t.Phase = 1
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, t.Id, tmpPath)
}
} else {
mappedBytes, _ := json.Marshal(textResult)
if len(mappedBytes) > 0 {
tmpPath, tmpErr := saveTmpResult(t.TaskID, mappedBytes, ".json")
if tmpErr == nil && tmpPath != "" {
if t.TmpFile != "" {
_ = os.Remove(t.TmpFile)
}
t.TmpFile = tmpPath
t.Phase = 1
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, t.Id, tmpPath)
}
}
}
// 7) 上传 OSS可重试
var oss *gateway.UploadFileResponse
for attempt := 0; attempt <= maxRetry; attempt++ {
if attempt > 0 {
@@ -150,35 +197,35 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *
}
}
// 6) 解析校验(可重试,失败重新调模型)
//if req.BuildType == 1 {
// for attempt := 0; attempt <= maxRetry; attempt++ {
// if attempt > 0 {
// g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID)
// }
// // 6.1) 校验数据
// err = util.ValidatePromptResult(textResult, model)
// if err == nil {
// break
// }
// g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v",
// t.TaskID, attempt, maxRetry, err)
// if attempt == maxRetry {
// w.failTask(ctx, t, fmt.Sprintf("JSON解析重试耗尽: %v", err))
// return
// }
// // 6.2) 重新调模型
// newResult, modelErr := w.callModel(ctx, t, model, payload)
// if modelErr != nil {
// g.Log().Warningf(ctx, "[执行任务][重试] 重新调模型失败 taskId=%s attempt=%d/%d err=%v",
// t.TaskID, attempt, maxRetry, modelErr)
// continue
// }
// textResult = newResult
// }
//}
//8) 解析校验(可重试,失败重新调模型)
if req.BuildType == 1 {
for attempt := 0; attempt <= maxRetry; attempt++ {
if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID)
}
// 6.1) 校验数据
err = util.ValidatePromptResult(textResult, model)
if err == nil {
break
}
g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v",
t.TaskID, attempt, maxRetry, err)
if attempt == maxRetry {
w.failTask(ctx, t, fmt.Sprintf("JSON解析重试耗尽: %v", err))
return
}
// 6.2) 重新调模型
newResult, modelErr := w.callModel(ctx, t, model, payload)
if modelErr != nil {
g.Log().Warningf(ctx, "[执行任务][重试] 重新调模型失败 taskId=%s attempt=%d/%d err=%v",
t.TaskID, attempt, maxRetry, modelErr)
continue
}
textResult = newResult
}
}
// 7) 成功回调
// 9) 成功回调
t.State = 2
t.OssFile = oss.FileAddressPrefix + oss.FileURL
t.FileType = oss.FileFormat
@@ -199,9 +246,40 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *
g.Log().Infof(ctx, "[执行任务][成功] taskId=%s fileType=%s textLen=%d callbackUrl=%s",
t.TaskID, oss.FileFormat, len(textResult), t.CallbackURL)
// 10) 删除临时文件
_ = os.Remove(t.TmpFile)
}
// callModelRaw 调用模型,返回原始字节(不做响应映射,用于流式输出)
func (w *asyncWorker) callModelRaw(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, payload map[string]any) ([]byte, error) {
var data []byte
var err error
if task.Phase == 1 && strings.TrimSpace(task.TmpFile) != "" {
data, err = os.ReadFile(task.TmpFile)
if err != nil || len(data) == 0 {
data = nil
}
}
if data == nil {
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName)
data, err = InvokeModel(ctx, model, payload, task.ModelKey)
if err != nil {
return nil, err
}
tmpPath, tmpErr := saveTmpResult(task.TaskID, data, "")
if tmpErr == nil && tmpPath != "" {
task.TmpFile = tmpPath
task.Phase = 1
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath)
}
}
return data, nil
}
// 返回: ossURL(成功时有值), fileType, textResult(失败时是错误信息), retryable(是否可重试)
// callModel 调用模型 + 检测文件类型 + 保存临时文件
func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, payload map[string]any) (map[string]any, error) {
@@ -302,14 +380,7 @@ func InvokeModel(ctx context.Context, model *entity.AsynchModel, payload map[str
msg := string(b)
return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
}
// 8响应参数映射
mappedResponse, err := util.MapResponsePayload(model.ResponseMapping, b)
if err != nil {
g.Log().Warningf(ctx, "响应参数映射失败: %v返回原始数据", err)
return b, nil
}
return mappedResponse, nil
}
// // InvokeModel 调用模型服务,返回二进制结果

View File

@@ -42,15 +42,8 @@ CREATE TABLE IF NOT EXISTS asynch_models (
remark TEXT DEFAULT '' -- 备注
response_token_field VARCHAR(128) NOT NULL DEFAULT ''; -- 响应中消耗token的字段映射
operator_name VARCHAR(64) NOT NULL DEFAULT '', -- 运营商名称
token_config JSONB NOT NULL DEFAULT '{
"zh_ratio": 1.0,
"en_ratio": 1.3,
"space_ratio": 0.1,
"punctuation_ratio": 0.1,
"max_window_size": 8192,
"reserve_ratio": 0.2,
"min_reserve": 512,
}'::jsonb -- Token配置
stream_config JSONB NOT NULL DEFAULT '{}'::jsonb, -- 流式配置
token_config JSONB NOT NULL DEFAULT '{}'::jsonb -- Token配置
extend_mapping JSONB NOT NULL DEFAULT '{}'::jsonb,
query_config JSONB NOT NULL DEFAULT '{}'::jsonb;
);