Files
media/service/asr/task_service.go
2026-05-22 17:07:36 +08:00

472 lines
16 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package asr
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
consts "media/consts/audio"
dao "media/dao/audio"
dto "media/model/dto/audio"
entity "media/model/entity/audio"
serviceScene "media/service/scene"
"gitea.com/red-future/common/beans"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
"github.com/gogf/gf/v2/util/guid"
)
// audioTaskService groups the transcription-task operations (create, process,
// query). It holds no fields.
type audioTaskService struct{}

// AudioTask is the package-level audioTaskService instance.
var AudioTask = &audioTaskService{}
// CreateTaskParams carries the inputs accepted by audioTaskService.Create.
type CreateTaskParams struct {
InputData []string // list of source URLs; one detail record is produced per entry
FileNames []string // list of file names; stored on the task record as JSON
Model string // recognition model; Create fills it from config key "whisper.model" (default "medium") when empty
Language string // recognition language; Create fills it from config key "whisper.language" (default "zh") when empty
Threshold float64 // forwarded to scene analysis; Create substitutes 0.3 when the value is <= 0
CallbackURL string // address notified when the task finishes (full URL including host, port and path)
}
// Create registers a new transcription task and returns its task ID right
// away; the actual work runs asynchronously in processTask.
//
// Defaults are written back into params: an empty Model or Language is filled
// from the "whisper.model" / "whisper.language" config keys (falling back to
// "medium" / "zh"), and a non-positive Threshold becomes 0.3.
//
// It returns an error when params is nil, when the URL or file-name list
// cannot be serialised, or when inserting the task row fails.
func (s *audioTaskService) Create(ctx context.Context, params *CreateTaskParams) (res *dto.CreateTaskRes, err error) {
	// A nil params would otherwise panic on the first field access below.
	if params == nil {
		return nil, fmt.Errorf("创建任务失败: 参数不能为空")
	}
	if params.Model == "" {
		params.Model = g.Cfg().MustGet(ctx, "whisper.model", "medium").String()
	}
	if params.Language == "" {
		params.Language = g.Cfg().MustGet(ctx, "whisper.language", "zh").String()
	}
	if params.Threshold <= 0 {
		params.Threshold = 0.3
	}

	// The input lists are stored on the task row as JSON text.
	inputBytes, err := json.Marshal(params.InputData)
	if err != nil {
		return nil, fmt.Errorf("创建任务失败: %v", err)
	}
	fnBytes, err := json.Marshal(params.FileNames)
	if err != nil {
		return nil, fmt.Errorf("创建任务失败: %v", err)
	}

	taskID := "tsk_" + guid.S()
	now := time.Now()
	task := &entity.TranscribeTask{
		TaskID:      taskID,
		Status:      consts.TaskStatusPending,
		Progress:    0,
		TotalFiles:  len(params.InputData),
		InputType:   consts.InputTypeURL,
		Model:       params.Model,
		Language:    params.Language,
		Threshold:   params.Threshold,
		InputData:   string(inputBytes),
		FileNames:   string(fnBytes),
		CallbackURL: params.CallbackURL,
	}
	task.CreatedAt = gconv.GTime(now)
	task.UpdatedAt = gconv.GTime(now)
	if _, daoErr := dao.TranscribeTask.Insert(ctx, task); daoErr != nil {
		return nil, fmt.Errorf("创建任务失败: %v", daoErr)
	}
	g.Log().Infof(ctx, "[创建任务 %s] 文件数=%d, 模型=%s, 语言=%s, 回调=%s",
		taskID, len(params.InputData), params.Model, params.Language, params.CallbackURL)

	// Capture the caller's user info now and hand it to the goroutine, which
	// builds its own context instead of reusing ctx.
	user := getUserFromCtx(ctx)

	// Process asynchronously; the caller only gets the task ID back.
	go s.processTask(user, taskID, params.InputData, params.Model, params.Language, params.Threshold, params.CallbackURL)
	return &dto.CreateTaskRes{TaskID: taskID}, nil
}
// processTask runs one transcription task to completion in the background.
//
// For every URL it downloads the file and hands it to processSingleVideo; when
// the download fails it records a failed detail row instead. Afterwards it
// serialises the collected results, stores them on the task and notifies
// callbackURL. Progress is reported as 5 at start, from 10 up to (but not
// including) 80 while iterating, and 95 while the result is being built.
//
// It is started with `go` from Create, so it builds its own background context
// carrying user under the "user" key. A panic anywhere in the run marks the
// task as failed and triggers a failure callback.
func (s *audioTaskService) processTask(user *beans.User, taskID string, urls []string, model, language string, threshold float64, callbackURL string) {
	ctx := context.Background()
	ctx = context.WithValue(ctx, "user", user)
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		// The failure handling below touches the database and the network. If
		// it panics too, nothing else in this goroutine would recover it and
		// the whole process would go down, so contain it here.
		defer func() {
			if r2 := recover(); r2 != nil {
				g.Log().Errorf(ctx, "[任务 %s] 失败处理再次异常: %v", taskID, r2)
			}
		}()
		errMsg := fmt.Sprintf("任务处理异常: %v", r)
		g.Log().Errorf(ctx, "[任务 %s] %s, 将通过回调通知调用方", taskID, errMsg)
		dao.TranscribeTask.UpdateError(ctx, taskID, errMsg)
		g.Log().Infof(ctx, "[任务 %s] 触发失败回调(panic恢复)", taskID)
		s.callback(ctx, taskID, consts.TaskStatusFailed, errMsg, callbackURL)
	}()

	total := len(urls)
	g.Log().Infof(ctx, "[任务 %s] 开始处理, 共%d个URL, 回调地址=%s", taskID, total, callbackURL)
	dao.TranscribeTask.UpdateTaskRunning(ctx, taskID, 5)
	g.Log().Infof(ctx, "[任务 %s] 状态已更新为 running, 进度=5", taskID)

	tempDir := getTempDir(ctx)
	// Keep going on failure: each download then reports its own error and is
	// recorded per file, but the root cause is visible in the log.
	if mkErr := os.MkdirAll(tempDir, 0755); mkErr != nil {
		g.Log().Warningf(ctx, "[任务 %s] 创建临时目录失败: %s, %v", taskID, tempDir, mkErr)
	}

	results := make([]dto.TranscribeItem, 0, total)
	successCount, failCount := 0, 0
	for i, videoURL := range urls {
		g.Log().Infof(ctx, "[任务 %s] 下载 %d/%d: %s", taskID, i+1, total, videoURL)
		progress := 10 + (i*70)/total
		dao.TranscribeTask.UpdateProgress(ctx, taskID, progress)
		g.Log().Debugf(ctx, "[任务 %s] 进度更新: %d/%d → %d%%", taskID, i+1, total, progress)

		savePath, dlErr := downloadFromURL(ctx, videoURL, tempDir)
		if dlErr != nil {
			// No file on disk, so use a positional placeholder name for the
			// detail row and the result entry.
			placeholder := fmt.Sprintf("url_%d.mp4", i+1)
			g.Log().Warningf(ctx, "[任务 %s] 文件%d/%d 下载失败: %v", taskID, i+1, total, dlErr)
			s.saveDetail(ctx, taskID, i, placeholder,
				"", "", 0, "", model, language, dlErr.Error())
			results = append(results, dto.TranscribeItem{
				FileName: placeholder,
				Error:    dlErr.Error(),
			})
			failCount++
			continue
		}

		fileName := filepath.Base(savePath)
		result := s.processSingleVideo(ctx, taskID, savePath, i, fileName, model, language, threshold)
		results = append(results, *result)
		if result.Error != "" {
			g.Log().Warningf(ctx, "[任务 %s] 文件%d/%d 处理失败: %s - %s", taskID, i+1, total, fileName, result.Error)
			failCount++
		} else {
			g.Log().Infof(ctx, "[任务 %s] 文件%d/%d 处理成功: %s", taskID, i+1, total, fileName)
			successCount++
		}
	}
	g.Log().Infof(ctx, "[任务 %s] 所有文件处理完毕, 成功=%d 失败=%d, 开始构建结果JSON", taskID, successCount, failCount)

	// Build the complete result JSON.
	dao.TranscribeTask.UpdateProgress(ctx, taskID, 95)
	g.Log().Debugf(ctx, "[任务 %s] 进度更新: 95%% (结果构建中)", taskID)
	resultObj := dto.TaskResult{Results: make([]dto.TaskResultItem, len(results))}
	for i, item := range results {
		itemDTO := dto.TaskResultItem{
			FileName: item.FileName,
			Error:    item.Error,
		}
		// Only successful items carry a *dto.TranscribeResult payload.
		if r, ok := item.Result.(*dto.TranscribeResult); ok {
			itemDTO.Result = &dto.TaskResultDTO{
				Text:          r.Text,
				Model:         r.Model,
				Language:      r.Language,
				AudioSize:     r.AudioSize,
				AudioDuration: r.AudioDuration,
				Scenes:        r.Scenes,
			}
		}
		resultObj.Results[i] = itemDTO
	}

	resultJSON, marshalErr := json.Marshal(resultObj)
	if marshalErr != nil {
		errMsg := "结果序列化失败: " + marshalErr.Error()
		g.Log().Errorf(ctx, "[任务 %s] %s", taskID, errMsg)
		dao.TranscribeTask.UpdateError(ctx, taskID, errMsg)
		s.callback(ctx, taskID, consts.TaskStatusFailed, errMsg, callbackURL)
		return
	}
	g.Log().Infof(ctx, "[任务 %s] 结果JSON序列化完成, 大小=%d字节", taskID, len(resultJSON))

	if _, err := dao.TranscribeTask.UpdateResult(ctx, taskID, string(resultJSON), successCount, failCount); err != nil {
		errMsg := fmt.Sprintf("保存结果失败: %v", err)
		g.Log().Errorf(ctx, "[任务 %s] 保存结果失败: %v, 通过回调发送结果", taskID, err)
		// Mirror the serialisation-failure path: callback builds its payload
		// from the stored task row, so without this the row (and the payload)
		// would keep its previous, non-final state.
		dao.TranscribeTask.UpdateError(ctx, taskID, errMsg)
		s.callback(ctx, taskID, consts.TaskStatusFailed, errMsg, callbackURL)
		return
	}
	g.Log().Infof(ctx, "[任务 %s] 结果已入库, 成功=%d 失败=%d, 触发成功回调", taskID, successCount, failCount)

	// callback itself is a no-op when callbackURL is empty.
	s.callback(ctx, taskID, consts.TaskStatusSuccess, "", callbackURL)
	g.Log().Infof(ctx, "[任务 %s] 全部处理流程结束", taskID)
}
// callback POSTs the task outcome to callbackURL as JSON. The payload is built
// from the stored task row and its detail rows, the same sources GetTask uses.
//
// It is best-effort: every failure is logged and swallowed, and nothing is
// returned. An empty callbackURL makes it a no-op.
//
// The status parameter is currently not used: the status in the payload and in
// the log line comes from the stored task row. errMsg is only written to the
// log. The user found in ctx is forwarded in the X-User-Info header.
func (s *audioTaskService) callback(ctx context.Context, taskID, status, errMsg, callbackURL string) {
	// Upper bound for the whole HTTP exchange. http.DefaultClient has no
	// timeout of its own, so a stalled receiver would otherwise block the task
	// goroutine forever.
	const callbackTimeout = 30 * time.Second
	// The response body is only read for logging; cap how much is kept.
	const maxLoggedResponse = 64 << 10

	if callbackURL == "" {
		return
	}

	task, taskErr := dao.TranscribeTask.GetByTaskID(ctx, taskID)
	if taskErr != nil {
		g.Log().Errorf(ctx, "[回调 %s] 查询任务失败: %v", taskID, taskErr)
		return
	}
	if task == nil {
		g.Log().Errorf(ctx, "[回调 %s] 任务不存在", taskID)
		return
	}

	// A failed detail query is not fatal: the callback is still sent with
	// whatever details are available.
	detailList, listErr := dao.TranscribeTaskDetail.ListByTaskID(ctx, taskID)
	if listErr != nil {
		g.Log().Warningf(ctx, "[回调 %s] 查询明细失败: %v", taskID, listErr)
	}
	detailItems := make([]dto.TranscribeTaskDetailItem, 0, len(detailList))
	for i := range detailList {
		detailItems = append(detailItems, dao.DetailEntityToItem(&detailList[i]))
	}

	// Build taskInfo the same way the query endpoint does.
	taskInfo := dao.EntityToItem(task)
	// Same as the query endpoint: fill in scenes etc. from the result JSON.
	detailItems = enrichDetailsFromResult(task.Result, detailItems)
	payload := dto.CallbackPayload{
		TaskInfo:   taskInfo,
		DetailList: detailItems,
	}
	body, marshalErr := json.Marshal(payload)
	if marshalErr != nil {
		g.Log().Errorf(ctx, "[回调 %s] 载荷序列化失败: %v", taskID, marshalErr)
		return
	}
	g.Log().Infof(ctx, "[回调 %s] 触发回调, 状态=%s, 成功=%d 失败=%d, 错误=%s, 目标=%s",
		taskID, taskInfo.Status, taskInfo.SuccessFiles, taskInfo.FailFiles, errMsg, callbackURL)
	g.Log().Debugf(ctx, "[回调 %s] 回调载荷长度=%d字节, 明细条数=%d",
		taskID, len(body), len(detailItems))

	// Pass the caller's user info through in the X-User-Info header.
	cbUser := getUserFromCtx(ctx)
	userJSON, userErr := json.Marshal(cbUser)
	if userErr != nil {
		g.Log().Warningf(ctx, "[回调 %s] 用户信息序列化失败: %v", taskID, userErr)
	}
	g.Log().Infof(ctx, "[回调 %s] curl -X POST '%s' -H 'Content-Type: application/json' -H 'X-User-Info: %s' -d '%s'",
		taskID, callbackURL, string(userJSON), strings.ReplaceAll(string(body), "'", "'\\''"))

	reqCtx, cancel := context.WithTimeout(ctx, callbackTimeout)
	defer cancel()
	// An unparsable callbackURL yields a nil request; using it would panic.
	req, buildErr := http.NewRequestWithContext(reqCtx, http.MethodPost, callbackURL, bytes.NewReader(body))
	if buildErr != nil {
		g.Log().Errorf(ctx, "[回调 %s] 构造请求失败: %v", taskID, buildErr)
		return
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("X-User-Info", string(userJSON))

	resp, reqErr := http.DefaultClient.Do(req)
	if reqErr != nil {
		g.Log().Errorf(ctx, "[回调 %s] 请求失败: %v", taskID, reqErr)
		return
	}
	defer resp.Body.Close()

	respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, maxLoggedResponse))
	if readErr != nil {
		g.Log().Warningf(ctx, "[回调 %s] 读取响应失败: %v", taskID, readErr)
	}
	g.Log().Infof(ctx, "[回调 %s] 响应 status=%d, body=%s", taskID, resp.StatusCode, string(respBody))
}
// processSingleVideo handles one downloaded file: it runs scene analysis, then
// speech recognition, writes one detail row via saveDetail and returns the
// per-file result.
//
// Everything up to and including the first "_" in fileName is dropped, unless
// the "_" is the very first character. A scene-analysis failure is only
// logged. A recognition failure removes the file at savePath, records a failed
// detail row and is reported through the Error field of the returned item.
func (s *audioTaskService) processSingleVideo(ctx context.Context, taskID, savePath string, fileIndex int, fileName, model, language string, threshold float64) *dto.TranscribeItem {
	if pos := strings.Index(fileName, "_"); pos > 0 {
		fileName = fileName[pos+1:]
	}
	g.Log().Infof(ctx, "[任务 %s] 开始处理文件(fileIndex=%d): %s", taskID, fileIndex, fileName)

	// Scene analysis: optional, scenes stays nil when it fails or finds nothing.
	var scenes *dto.SceneSummaryDTO
	analyzeReq := &serviceScene.SceneAnalyzeReq{
		VideoPaths:       []string{savePath},
		Threshold:        threshold,
		ExtractKeyframes: false,
	}
	analysis, analyzeErr := serviceScene.SceneAnalyzer.Analyze(ctx, analyzeReq)
	switch {
	case analyzeErr != nil:
		g.Log().Warningf(ctx, "[任务 %s] 文件 %s 分镜分析失败: %v", taskID, fileName, analyzeErr)
	case len(analysis.Analyses) == 0:
		g.Log().Infof(ctx, "[任务 %s] 文件 %s 分镜分析无结果", taskID, fileName)
	default:
		scenes = toSceneDTO(&analysis.Analyses[0])
		g.Log().Infof(ctx, "[任务 %s] 文件 %s 分镜分析完成, 场景数=%d", taskID, fileName, scenes.TotalScenes)
	}

	// Speech recognition.
	g.Log().Infof(ctx, "[任务 %s] 文件 %s 开始语音识别, 模型=%s, 语言=%s", taskID, fileName, model, language)
	transcribeReq := &VideoTranscribeReq{
		VideoPath: savePath,
		Model:     model,
		Language:  language,
	}
	transcript, transcribeErr := VideoTranscribe.TranscribeVideo(ctx, transcribeReq)
	if transcribeErr != nil {
		g.Log().Errorf(ctx, "[任务 %s] 文件 %s 语音识别失败: %v", taskID, fileName, transcribeErr)
		os.Remove(savePath)
		reason := transcribeErr.Error()
		s.saveDetail(ctx, taskID, fileIndex, fileName,
			"", "", 0, "", model, language, reason)
		return &dto.TranscribeItem{FileName: fileName, Error: reason}
	}
	g.Log().Infof(ctx, "[任务 %s] 文件 %s 语音识别成功, 文本长度=%d, 音频时长=%s, 大小=%d",
		taskID, fileName, len(transcript.Text), transcript.AudioDuration, transcript.AudioSize)

	// The detail row stores scenes as JSON text, or "" when there are none.
	scenesJSON := ""
	if scenes != nil {
		encoded, _ := json.Marshal(scenes)
		scenesJSON = string(encoded)
	}
	s.saveDetail(ctx, taskID, fileIndex, fileName,
		transcript.Text, scenesJSON, transcript.AudioSize, transcript.AudioDuration, model, language, "")

	item := &dto.TranscribeItem{FileName: fileName}
	item.Result = &dto.TranscribeResult{
		Text:          transcript.Text,
		Model:         transcript.Model,
		Language:      transcript.Language,
		AudioPath:     transcript.AudioPath,
		AudioSize:     transcript.AudioSize,
		AudioDuration: transcript.AudioDuration,
		Scenes:        scenes,
	}
	return item
}
// getUserFromCtx resolves the user attached to ctx. It looks, in order, at the
// *beans.User stored under the "user" context key, then at what
// utils.GetUserInfo returns for ctx, and finally falls back to a default user
// (UserName "admin", TenantId 1).
func getUserFromCtx(ctx context.Context) *beans.User {
	// A failed assertion also covers the "no value stored" case.
	if stored, ok := ctx.Value("user").(*beans.User); ok {
		return stored
	}
	// Otherwise let the common library try to resolve the user.
	if parsed, err := utils.GetUserInfo(ctx); err == nil && parsed != nil {
		return parsed
	}
	return &beans.User{UserName: "admin", TenantId: 1}
}
// saveDetail inserts one per-file detail row through the
// dao.TranscribeTaskDetail DAO. The outcome is only logged; an insert failure
// is not returned to the caller.
func (s *audioTaskService) saveDetail(ctx context.Context, taskID string, fileIndex int, fileName, text, scenes string, audioSize int64, audioDuration, model, language, errMsg string) {
	row := entity.TranscribeTaskDetail{
		TaskID:          taskID,
		FileIndex:       fileIndex,
		FileName:        fileName,
		TranscribedText: text,
		Scenes:          scenes,
		AudioSize:       audioSize,
		AudioDuration:   audioDuration,
		Model:           model,
		Language:        language,
		ErrorMessage:    errMsg,
	}
	_, insertErr := dao.TranscribeTaskDetail.Insert(ctx, &row)
	if insertErr != nil {
		g.Log().Errorf(ctx, "[任务 %s] 保存明细失败(fileIndex=%d): %v", taskID, fileIndex, insertErr)
		return
	}
	hasError := errMsg != ""
	g.Log().Debugf(ctx, "[任务 %s] 明细已保存(fileIndex=%d, fileName=%s, 有错误=%v)",
		taskID, fileIndex, fileName, hasError)
}
// ---------- Task queries ----------

// GetTask returns the task identified by req.TaskID together with its
// per-file detail list.
//
// It fails when the task ID is empty, when the task lookup errors, or when no
// such task exists. A failing detail lookup is only logged, and the task is
// still returned with whatever details were obtained.
func (s *audioTaskService) GetTask(ctx context.Context, req *dto.GetTaskReq) (res *dto.GetTaskRes, err error) {
	taskID := req.TaskID
	if taskID == "" {
		return nil, fmt.Errorf("taskId不能为空")
	}

	task, queryErr := dao.TranscribeTask.GetByTaskID(ctx, taskID)
	switch {
	case queryErr != nil:
		return nil, fmt.Errorf("查询任务失败: %v", queryErr)
	case task == nil:
		return nil, fmt.Errorf("任务不存在: %s", taskID)
	}

	rows, listErr := dao.TranscribeTaskDetail.ListByTaskID(ctx, taskID)
	if listErr != nil {
		g.Log().Warningf(ctx, "[任务 %s] 查询明细失败: %v", taskID, listErr)
	}
	g.Log().Infof(ctx, "[查询任务] taskId=%s, 状态=%s, 进度=%d", taskID, task.Status, task.Progress)

	details := make([]dto.TranscribeTaskDetailItem, len(rows))
	for i := range rows {
		details[i] = dao.DetailEntityToItem(&rows[i])
	}
	// Compatibility with older data: when a detail has no scenes but the task
	// has a result JSON, take the scenes from that result.
	details = enrichDetailsFromResult(task.Result, details)

	out := &dto.GetTaskRes{
		TaskInfo:   dao.EntityToItem(task),
		DetailList: details,
	}
	return out, nil
}
// enrichDetailsFromResult fills gaps in details from the task's result JSON.
//
// Only details whose Scenes is empty are touched. Each one is matched by
// FileName against the first result entry that carries scenes; Scenes is then
// set from that entry, and AudioDuration, AudioSize, Model and Language are
// copied over where the detail has none. details is modified in place and
// returned. It is returned untouched when resultJSON is empty or unparsable,
// or when there are no details.
func enrichDetailsFromResult(resultJSON string, details []dto.TranscribeTaskDetailItem) []dto.TranscribeTaskDetailItem {
	if len(details) == 0 || resultJSON == "" {
		return details
	}
	var parsed dto.TaskResult
	if json.Unmarshal([]byte(resultJSON), &parsed) != nil {
		return details
	}

	for i := range details {
		detail := &details[i]
		if detail.Scenes != "" {
			// Scenes already present, nothing to fill in.
			continue
		}
		for j := range parsed.Results {
			candidate := &parsed.Results[j]
			if candidate.Result == nil || candidate.Result.Scenes == nil || candidate.FileName != detail.FileName {
				continue
			}
			encoded, _ := json.Marshal(candidate.Result.Scenes)
			detail.Scenes = string(encoded)
			// Fill the other fields that may be missing as well.
			if detail.AudioDuration == "" {
				detail.AudioDuration = candidate.Result.AudioDuration
			}
			if detail.AudioSize == 0 {
				detail.AudioSize = candidate.Result.AudioSize
			}
			if detail.Model == "" {
				detail.Model = candidate.Result.Model
			}
			if detail.Language == "" {
				detail.Language = candidate.Result.Language
			}
			// Only the first matching entry is used.
			break
		}
	}
	return details
}
// GetProgress returns the progress view of the task identified by req.TaskID.
// It fails when the task ID is empty, when the lookup errors, or when no such
// task exists.
func (s *audioTaskService) GetProgress(ctx context.Context, req *dto.GetProgressReq) (res *dto.GetProgressRes, err error) {
	taskID := req.TaskID
	if taskID == "" {
		return nil, fmt.Errorf("taskId不能为空")
	}

	task, queryErr := dao.TranscribeTask.GetByTaskID(ctx, taskID)
	switch {
	case queryErr != nil:
		return nil, fmt.Errorf("查询任务失败: %v", queryErr)
	case task == nil:
		return nil, fmt.Errorf("任务不存在: %s", taskID)
	}

	progress := dao.EntityToProgress(task)
	g.Log().Infof(ctx, "[查询进度] taskId=%s, 状态=%s, 进度=%d", taskID, progress.Status, progress.Progress)
	return &progress, nil
}
// ListTasks returns one page of tasks matching req together with the total
// count reported by the DAO. When req.Page is nil it is set to page 1 with a
// page size of 10.
func (s *audioTaskService) ListTasks(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
	// Make sure the paging parameters are never nil.
	if req.Page == nil {
		req.Page = &beans.Page{PageNum: 1, PageSize: 10}
	}

	rows, total, listErr := dao.TranscribeTask.List(ctx, req)
	if listErr != nil {
		return nil, fmt.Errorf("查询任务列表失败: %v", listErr)
	}

	items := make([]dto.TranscribeTaskItem, 0, len(rows))
	for i := range rows {
		items = append(items, dao.EntityToItem(&rows[i]))
	}
	g.Log().Infof(ctx, "[查询列表] status=%s, pageNum=%d, pageSize=%d, 命中=%d/总量=%d",
		req.Status, req.Page.PageNum, req.Page.PageSize, len(items), total)

	out := &dto.ListTaskRes{List: items, Total: total}
	return out, nil
}