Files
ai-agent/digitalhuman/cmd/main.go
2026-04-27 11:07:21 +08:00

196 lines
5.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"strings"
"time"
)
var text = "欢迎使用红动未来数字人服务平台我们将为您提供最优质的AI数字人解决方案。"
type TTSCommonResponse struct {
Code int `json:"code"`
Msg string `json:"msg"`
Text string `json:"text"`
Audio string `json:"audio"`
}
func main() {
// 获取当前工作目录
outputDir, err := os.Getwd()
if err != nil {
fmt.Printf("获取当前目录失败: %v\n", err)
os.Exit(1)
}
// 查找项目根目录(向上查找包含 go.mod 的目录)
outputDir = findProjectRoot(outputDir)
// 验证根目录是否正确(检查是否有 go.mod
if _, err := os.Stat(outputDir + "/go.mod"); err != nil {
fmt.Printf("未找到项目根目录,当前目录: %s\n", outputDir)
os.Exit(1)
}
fmt.Println("=================== TTS测试开始 ===================")
fmt.Printf("输出目录: %s\n", outputDir)
fmt.Printf("随机文本: %s\n", text)
fmt.Printf("请求URL: http://127.0.0.1:8000/tts\n")
// 创建带超时的 HTTP 客户端120秒超时
client := &http.Client{
Timeout: 120 * time.Second,
}
resp, err := client.Post("http://127.0.0.1:8000/tts", "application/json", bytes.NewBufferString(fmt.Sprintf(`"%s"`, text)))
if err != nil {
fmt.Printf("请求失败: %v\n", err)
os.Exit(1)
}
defer resp.Body.Close()
// 打印响应头
fmt.Printf("Content-Type: %s\n", resp.Header.Get("Content-Type"))
fmt.Printf("Content-Length: %s\n", resp.Header.Get("Content-Length"))
body, err := io.ReadAll(resp.Body)
if err != nil {
fmt.Printf("读取响应失败: %v\n", err)
os.Exit(1)
}
fmt.Printf("状态码: %d, 响应大小: %d字节\n", resp.StatusCode, len(body))
// 打印响应内容的前200字节用于调试
if len(body) > 0 {
previewLen := minInt(200, len(body))
fmt.Printf("响应内容预览(前%d字节): ", previewLen)
if len(body) >= 4 && string(body[:4]) == "RIFF" {
// WAV文件头
fmt.Printf("WAV文件格式 (RIFF...)\n")
} else if len(body) >= 3 && string(body[:3]) == "ID3" {
// MP3 ID3标签
fmt.Printf("MP3 ID3格式\n")
} else if len(body) >= 2 && body[0] == 0xFF && (body[1]&0xE0) == 0xE0 {
// MP3帧同步
fmt.Printf("MP3帧格式\n")
} else {
// 可能是JSON或其他格式
fmt.Printf("%s\n", string(body[:previewLen]))
}
} else {
fmt.Printf("响应内容为空!\n")
os.Exit(1)
}
// 尝试解析JSON响应包含base64音频
var commonResp TTSCommonResponse
var audioData []byte
var ext string
if json.Unmarshal(body, &commonResp) == nil && commonResp.Audio != "" && commonResp.Audio != "base64_placeholder" {
fmt.Printf("检测到JSON响应code=%d, msg=%s\n", commonResp.Code, commonResp.Msg)
fmt.Printf("Audio字段长度: %d 字符\n", len(commonResp.Audio))
// 检查是否成功
if commonResp.Code != 0 {
fmt.Printf("TTS服务返回错误: %s\n", commonResp.Msg)
os.Exit(1)
}
// 解码base64音频数据
decoded, err := base64.StdEncoding.DecodeString(commonResp.Audio)
if err != nil {
fmt.Printf("base64解码失败: %v\n", err)
os.Exit(1)
}
if len(decoded) == 0 {
fmt.Printf("解码后数据为空!\n")
os.Exit(1)
}
audioData = decoded
fmt.Printf("解码后音频数据大小: %d 字节\n", len(audioData))
// 根据解码后的音频数据格式决定扩展名
if len(audioData) >= 4 && string(audioData[:4]) == "RIFF" {
ext = ".wav"
fmt.Printf("检测到WAV格式\n")
} else if len(audioData) >= 3 && string(audioData[:3]) == "ID3" || (len(audioData) >= 2 && audioData[0] == 0xFF && (audioData[1]&0xE0) == 0xE0) {
ext = ".mp3"
fmt.Printf("检测到MP3格式\n")
} else {
ext = ".wav" // 默认wav
fmt.Printf("未知格式,默认保存为 .wav\n")
}
} else {
// 直接是二进制音频数据
audioData = body
// 根据音频数据格式决定扩展名
if len(audioData) >= 4 && string(audioData[:4]) == "RIFF" {
ext = ".wav"
} else if len(audioData) >= 3 && string(audioData[:3]) == "ID3" || (len(audioData) >= 2 && audioData[0] == 0xFF && (audioData[1]&0xE0) == 0xE0) {
ext = ".mp3"
} else {
ext = ".wav" // 默认wav
}
}
// 保存音频文件
filename := fmt.Sprintf("%s/tts_output_%d%s", outputDir, time.Now().Unix(), ext)
if err = os.WriteFile(filename, audioData, 0644); err != nil {
fmt.Printf("写文件失败: %v\n", err)
os.Exit(1)
}
fmt.Printf("音频已保存: %s (%d字节)\n", filename, len(audioData))
fmt.Println("=================== TTS测试成功 ===================")
}
func maxInt(a, b int) int {
if a > b {
return a
}
return b
}
func minInt(a, b int) int {
if a < b {
return a
}
return b
}
// findProjectRoot 查找项目根目录(包含 go.mod 的目录)
func findProjectRoot(startDir string) string {
dir := startDir
for {
// 检查当前目录是否有 go.mod
if _, err := os.Stat(dir + "/go.mod"); err == nil {
return dir
}
// 如果已经是根目录或无法继续向上查找,返回当前目录
parentDir := dir[:maxInt(0, len(dir)-len("/"+getLastPathSegment(dir)))]
if parentDir == dir || parentDir == "" {
return startDir
}
dir = parentDir
}
}
// getLastPathSegment 获取路径的最后一部分
func getLastPathSegment(path string) string {
if idx := strings.LastIndex(path, "/"); idx != -1 {
return path[idx+1:]
}
return path
}