151 lines
3.7 KiB
Go
151 lines
3.7 KiB
Go
|
|
package util
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"encoding/base64"
|
|||
|
|
"encoding/json"
|
|||
|
|
"fmt"
|
|||
|
|
"sort"
|
|||
|
|
"strings"
|
|||
|
|
|
|||
|
|
"github.com/gogf/gf/v2/encoding/gjson"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// ================================================================
|
|||
|
|
|
|||
|
|
// ParseStreamResponse 流式响应解析(通用入口)
|
|||
|
|
func ParseStreamResponse(rawBytes []byte, streamConfig map[string]any) (map[string]any, error) {
|
|||
|
|
enabled, _ := streamConfig["enabled"].(bool)
|
|||
|
|
if !enabled {
|
|||
|
|
return gjson.New(string(rawBytes)).Map(), nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
parser, _ := streamConfig["parser"].(string)
|
|||
|
|
if parser == "base64_concat" {
|
|||
|
|
return parseBase64Stream(rawBytes)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return parseSSEStream(rawBytes, streamConfig)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// parseBase64Stream reassembles base64 audio delivered as newline-delimited
// JSON objects and decodes it to raw bytes (used for TTS-style audio models).
//
// Every non-empty line is parsed as a JSON object and the string stored under
// its "data" key is collected, with spaces, tabs, CR and LF removed. Lines
// that are not valid JSON, or that have no string "data" value, are skipped on
// purpose so that unrelated lines in the stream do not abort decoding.
//
// Decoding is attempted in this order:
//  1. all fragments concatenated, as padded standard base64;
//  2. all fragments concatenated, as unpadded standard base64;
//  3. each fragment decoded on its own (padded, then unpadded) with the
//     resulting bytes appended in stream order. This covers streams whose
//     chunks are independently encoded, where padding appears in the middle
//     of the concatenation and makes it invalid as a single base64 string.
//
// On success it returns a map with a single key "audio" holding the decoded
// []byte (empty when the stream carried no data). If every strategy fails,
// the error from decoding the concatenated payload is returned, wrapped.
func parseBase64Stream(rawBytes []byte) (map[string]any, error) {
	// stripSpace drops the whitespace characters that may be embedded in a
	// base64 payload; strings.Map removes runes mapped to -1.
	stripSpace := func(r rune) rune {
		switch r {
		case ' ', '\n', '\r', '\t':
			return -1
		}
		return r
	}
	// decode accepts both padded and unpadded standard base64.
	decode := func(s string) ([]byte, error) {
		if out, err := base64.StdEncoding.DecodeString(s); err == nil {
			return out, nil
		}
		return base64.RawStdEncoding.DecodeString(s)
	}

	var fragments []string
	for _, line := range strings.Split(string(rawBytes), "\n") {
		line = strings.TrimSpace(line)
		if line == "" {
			continue
		}

		var chunk map[string]any
		if err := json.Unmarshal([]byte(line), &chunk); err != nil {
			// Best effort: ignore lines that are not JSON objects.
			continue
		}

		data, ok := chunk["data"].(string)
		if !ok {
			continue
		}
		if data = strings.Map(stripSpace, data); data != "" {
			fragments = append(fragments, data)
		}
	}

	audioBytes, err := decode(strings.Join(fragments, ""))
	if err != nil {
		// The concatenation is not valid base64; try treating every fragment
		// as an independently encoded piece before giving up.
		var joined []byte
		for _, frag := range fragments {
			part, partErr := decode(frag)
			if partErr != nil {
				return nil, fmt.Errorf("base64 解码失败: %w", err)
			}
			joined = append(joined, part...)
		}
		audioBytes = joined
	}

	return map[string]any{"audio": audioBytes}, nil
}
|
|||
|
|
|
|||
|
|
// parseSSEStream SSE 流式解析(图片模型等)
|
|||
|
|
func parseSSEStream(rawBytes []byte, streamConfig map[string]any) (map[string]any, error) {
|
|||
|
|
events, _ := streamConfig["events"].([]any)
|
|||
|
|
if len(events) == 0 {
|
|||
|
|
return gjson.New(string(rawBytes)).Map(), nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
lines := strings.Split(string(rawBytes), "\n")
|
|||
|
|
result := make(map[string]any)
|
|||
|
|
var partials []map[string]any
|
|||
|
|
|
|||
|
|
for _, line := range lines {
|
|||
|
|
line = strings.TrimSpace(line)
|
|||
|
|
if line == "" || line == "[DONE]" {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
if strings.HasPrefix(line, "event:") {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
if strings.HasPrefix(line, "data:") {
|
|||
|
|
line = strings.TrimPrefix(line, "data:")
|
|||
|
|
line = strings.TrimSpace(line)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
var chunk map[string]any
|
|||
|
|
if err := json.Unmarshal([]byte(line), &chunk); err != nil {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
chunkType, _ := chunk["type"].(string)
|
|||
|
|
|
|||
|
|
for _, evt := range events {
|
|||
|
|
e, _ := evt.(map[string]any)
|
|||
|
|
match, _ := e["match"].(string)
|
|||
|
|
if !strings.Contains(chunkType, match) {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
fields, _ := e["fields"].(map[string]any)
|
|||
|
|
aggregateTo, _ := e["aggregate_to"].(string)
|
|||
|
|
evtType, _ := e["type"].(string)
|
|||
|
|
|
|||
|
|
switch evtType {
|
|||
|
|
case "partial":
|
|||
|
|
item := make(map[string]any)
|
|||
|
|
for localKey, chunkKey := range fields {
|
|||
|
|
item[localKey] = chunk[chunkKey.(string)]
|
|||
|
|
}
|
|||
|
|
partials = append(partials, item)
|
|||
|
|
|
|||
|
|
case "final":
|
|||
|
|
for localKey, chunkKey := range fields {
|
|||
|
|
val := gjson.New(chunk).Get(chunkKey.(string))
|
|||
|
|
if !val.IsNil() {
|
|||
|
|
if _, exists := result[aggregateTo]; !exists {
|
|||
|
|
result[aggregateTo] = make(map[string]any)
|
|||
|
|
}
|
|||
|
|
result[aggregateTo].(map[string]any)[localKey] = val.Val()
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if len(partials) > 0 {
|
|||
|
|
for _, evt := range events {
|
|||
|
|
e, _ := evt.(map[string]any)
|
|||
|
|
if e["type"] == "partial" {
|
|||
|
|
if orderBy, ok := e["order_by"].(string); ok {
|
|||
|
|
sort.Slice(partials, func(i, j int) bool {
|
|||
|
|
return fmt.Sprint(partials[i][orderBy]) < fmt.Sprint(partials[j][orderBy])
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
result[e["aggregate_to"].(string)] = partials
|
|||
|
|
break
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
mergedBytes, _ := json.Marshal(result)
|
|||
|
|
return gjson.New(mergedBytes).Map(), nil
|
|||
|
|
}
|