Files
model-gateway/common/util/streaming.go

151 lines
3.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package util
import (
"encoding/base64"
"encoding/json"
"fmt"
"sort"
"strings"
"github.com/gogf/gf/v2/encoding/gjson"
)
// ================================================================
// ParseStreamResponse is the generic entry point for decoding a model
// response body according to streamConfig.
//
// When streamConfig["enabled"] is not the boolean true, rawBytes is treated
// as one JSON document and converted to a map with gjson. The returned error
// is always nil on that path. Otherwise streamConfig["parser"] picks the
// decoder: "base64_concat" delegates to parseBase64Stream, and any other value
// (including a missing one) delegates to parseSSEStream with the same config.
func ParseStreamResponse(rawBytes []byte, streamConfig map[string]any) (map[string]any, error) {
	if on, _ := streamConfig["enabled"].(bool); !on {
		return gjson.New(string(rawBytes)).Map(), nil
	}
	switch mode, _ := streamConfig["parser"].(string); mode {
	case "base64_concat":
		return parseBase64Stream(rawBytes)
	default:
		return parseSSEStream(rawBytes, streamConfig)
	}
}
// parseBase64Stream reassembles a line-delimited stream of JSON chunks whose
// string field "data" carries base64 text (e.g. audio from a TTS model). It
// decodes that text into raw bytes, returned under the "audio" key.
//
// Every line is trimmed, and an optional SSE "data:" prefix is stripped, so
// NDJSON and SSE framing are handled alike. Lines that are not JSON objects
// ("[DONE]", "event: ..." and similar) and chunks without a non-empty string
// "data" field are skipped. Spaces, tabs, CR and LF inside the base64 text are
// removed before decoding.
//
// Two chunking styles are supported:
//   - one base64 string split at arbitrary points across chunks: the pieces
//     are concatenated and decoded as a whole, padded or unpadded;
//   - chunks that are each a complete, independently padded base64 string:
//     concatenation leaves "=" padding mid-string, which is invalid base64,
//     so on failure each chunk is decoded separately and the bytes are joined.
//
// An empty stream yields an empty, non-nil byte slice and no error. If neither
// strategy can decode the data, the error from decoding the concatenated text
// is returned, wrapped.
func parseBase64Stream(rawBytes []byte) (map[string]any, error) {
	// dropSpace removes whitespace that some providers insert into long
	// base64 payloads; it is not part of the encoded data.
	dropSpace := func(r rune) rune {
		if r == ' ' || r == '\n' || r == '\r' || r == '\t' {
			return -1
		}
		return r
	}
	// decode accepts both padded and unpadded standard base64.
	decode := func(s string) ([]byte, error) {
		b, err := base64.StdEncoding.DecodeString(s)
		if err != nil {
			b, err = base64.RawStdEncoding.DecodeString(s)
		}
		return b, err
	}

	var pieces []string
	for _, line := range strings.Split(string(rawBytes), "\n") {
		// Accept both NDJSON lines and SSE "data: {...}" lines.
		line = strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(line), "data:"))
		if line == "" {
			continue
		}
		var chunk map[string]any
		// Non-JSON lines are protocol noise rather than payload, so they are
		// skipped deliberately instead of failing the whole stream.
		if err := json.Unmarshal([]byte(line), &chunk); err != nil {
			continue
		}
		if data, ok := chunk["data"].(string); ok && data != "" {
			pieces = append(pieces, strings.Map(dropSpace, data))
		}
	}

	audioBytes, err := decode(strings.Join(pieces, ""))
	if err != nil {
		// Each chunk may be its own padded base64 string, in which case the
		// concatenation is invalid; decode them one by one instead.
		var joined []byte
		for _, p := range pieces {
			b, perr := decode(p)
			if perr != nil {
				return nil, fmt.Errorf("base64 解码失败: %w", err)
			}
			joined = append(joined, b...)
		}
		audioBytes = joined
	}
	return map[string]any{"audio": audioBytes}, nil
}
// parseSSEStream parses a Server-Sent-Events style stream (used e.g. by image
// models) into a single result map, driven by streamConfig["events"].
//
// Each entry of streamConfig["events"] is read as a map with these keys:
//   - "match":        substring that a chunk's "type" field must contain
//     (an empty or missing match therefore matches every chunk);
//   - "type":         "partial" or "final";
//   - "fields":       map of result key -> chunk key (a gjson Get pattern for
//     "final", a top-level chunk key for "partial");
//   - "aggregate_to": result key under which extracted values are stored;
//   - "order_by":     optional, partial only: item key used to sort partials.
//
// Lines are trimmed. Blank lines, "[DONE]" and "event:" lines are skipped, a
// leading "data:" is stripped, and lines that are not JSON objects are
// ignored. For each chunk, every event whose match is contained in the chunk's
// "type" is applied:
//   - "partial" copies the listed chunk fields into a new item;
//   - "final" copies the listed values that are present into
//     result[aggregate_to], creating that map on first use.
//
// All partial items, whichever partial events produced them, are collected
// into one list. That list is stored under the aggregate_to of the first
// "partial" event in the config, sorted by its order_by key when one is given.
// Sorting is stable and orders numbers (JSON numbers decode as float64)
// numerically and before non-numeric values. Other values compare by their
// fmt.Sprint form.
//
// If no events are configured, rawBytes is parsed as a single JSON document.
// The merged result is round-tripped through JSON before being returned.
// Field entries whose chunk key is not a string are skipped. An error is
// returned if the first partial event has no string "aggregate_to" while
// partial items exist, or if the merged result cannot be JSON-encoded.
func parseSSEStream(rawBytes []byte, streamConfig map[string]any) (map[string]any, error) {
	events, _ := streamConfig["events"].([]any)
	if len(events) == 0 {
		return gjson.New(string(rawBytes)).Map(), nil
	}

	result := make(map[string]any)
	var partials []map[string]any
	for _, line := range strings.Split(string(rawBytes), "\n") {
		line = strings.TrimSpace(line)
		if line == "" || line == "[DONE]" || strings.HasPrefix(line, "event:") {
			continue
		}
		if strings.HasPrefix(line, "data:") {
			line = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
		}
		var chunk map[string]any
		// Best-effort: non-JSON lines (comments, keep-alives, "data: [DONE]")
		// carry no payload and are skipped deliberately.
		if err := json.Unmarshal([]byte(line), &chunk); err != nil {
			continue
		}
		chunkType, _ := chunk["type"].(string)
		for _, evt := range events {
			e, _ := evt.(map[string]any)
			match, _ := e["match"].(string)
			if !strings.Contains(chunkType, match) {
				continue
			}
			fields, _ := e["fields"].(map[string]any)
			aggregateTo, _ := e["aggregate_to"].(string)
			evtType, _ := e["type"].(string)
			switch evtType {
			case "partial":
				item := make(map[string]any, len(fields))
				for localKey, chunkKey := range fields {
					key, ok := chunkKey.(string)
					if !ok {
						continue // malformed config entry; skip instead of panicking
					}
					item[localKey] = chunk[key]
				}
				partials = append(partials, item)
			case "final":
				doc := gjson.New(chunk)
				for localKey, chunkKey := range fields {
					path, ok := chunkKey.(string)
					if !ok {
						continue // malformed config entry; skip instead of panicking
					}
					val := doc.Get(path)
					if val.IsNil() {
						continue
					}
					target, ok := result[aggregateTo].(map[string]any)
					if !ok {
						target = make(map[string]any)
						result[aggregateTo] = target
					}
					target[localKey] = val.Val()
				}
			}
		}
	}

	if len(partials) > 0 {
		// orderLess is a total order: numbers numerically, numbers before
		// non-numbers, everything else by string form. Comparing fmt.Sprint
		// output alone would put index 10 before index 2.
		orderLess := func(a, b any) bool {
			fa, aNum := a.(float64)
			fb, bNum := b.(float64)
			switch {
			case aNum && bNum:
				return fa < fb
			case aNum != bNum:
				return aNum
			default:
				return fmt.Sprint(a) < fmt.Sprint(b)
			}
		}
		for _, evt := range events {
			e, _ := evt.(map[string]any)
			if e["type"] != "partial" {
				continue
			}
			aggregateTo, ok := e["aggregate_to"].(string)
			if !ok {
				return nil, fmt.Errorf("stream partial event %v: missing string aggregate_to", e["match"])
			}
			if orderBy, ok := e["order_by"].(string); ok {
				sort.SliceStable(partials, func(i, j int) bool {
					return orderLess(partials[i][orderBy], partials[j][orderBy])
				})
			}
			result[aggregateTo] = partials
			break
		}
	}

	mergedBytes, err := json.Marshal(result)
	if err != nil {
		return nil, fmt.Errorf("encode merged stream result: %w", err)
	}
	return gjson.New(mergedBytes).Map(), nil
}