Files
ai-agent/digital-human/service/async_task_service.go
2026-04-27 14:02:43 +08:00

196 lines
5.9 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"context"
"fmt"
"ai-agent/digital-human/consts"
"ai-agent/digital-human/consts/public"
"ai-agent/digital-human/dao"
"ai-agent/digital-human/model/dto"
"ai-agent/digital-human/model/entity"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gtime"
)
// asyncTaskService is a stateless receiver type that groups the async-task
// synchronisation methods.
type asyncTaskService struct{}

// AsyncTask is the shared async-task sync service instance, intended to be
// called by scheduled jobs and business-side polling.
var AsyncTask = &asyncTaskService{}
// Sync reconciles pending async-task bindings with the model-asynch middleware.
//
// Steps:
//  1. Load up to limit pending bindings through dao.AsyncTaskRef.ListPending.
//     limit defaults to 200 and is overridden by a positive req.Limit. Per the
//     original notes these are digital_human_async_task_ref rows in state 0/1,
//     i.e. still "generating" on the business side.
//  2. Query all collected task IDs in one call to getModelAsynchTaskBatch
//     (per the original notes: model-asynch /task/get-task-batch).
//  3. Map the middleware state onto the business tri-state
//     (generating / success / failed):
//     - 0/1/3 (queued / running / failed-but-may-retry, per the original
//     notes): still generating, nothing is written.
//     - 2/4 (success / downloaded): the business row is marked successful
//     with the returned OSS file and the binding row is updated. An empty
//     OSS file, or a failing business update, marks the business row as
//     failed instead and records the reason on the binding row.
//     - task ID missing from the response: the business row is marked as
//     failed and the binding row is soft-deleted (deleted_at is set).
//
// Writes performed while settling a single task are best-effort: a failing
// write is logged as a warning and the loop carries on with the next task.
// Only an error from listing the bindings or from the batch query aborts the
// sync and is returned to the caller.
//
// The result holds one entry per processed task. Total is the number of
// queried task IDs and Handled counts the tasks settled as successful.
func (s *asyncTaskService) Sync(ctx context.Context, req *dto.SyncAsyncTasksReq) (res *dto.SyncAsyncTasksRes, err error) {
	limit := 200
	if req != nil && req.Limit > 0 {
		limit = req.Limit
	}

	refs, err := dao.AsyncTaskRef.ListPending(ctx, limit)
	if err != nil {
		return nil, err
	}

	// Index the bindings by task ID; rows without a task ID cannot be queried.
	taskIDs := make([]string, 0, len(refs))
	refMap := make(map[string]*entity.AsyncTaskRef, len(refs))
	for _, r := range refs {
		if r == nil || r.TaskID == "" {
			continue
		}
		taskIDs = append(taskIDs, r.TaskID)
		refMap[r.TaskID] = r
	}

	out := &dto.SyncAsyncTasksRes{
		Total: len(taskIDs),
		List:  make([]dto.SyncAsyncTasksItem, 0, len(taskIDs)),
	}
	if len(taskIDs) == 0 {
		return out, nil
	}

	items, err := getModelAsynchTaskBatch(ctx, taskIDs)
	if err != nil {
		return nil, err
	}

	// markBizFailed flags the business row as failed. It is best-effort, but a
	// failing write is logged instead of being silently dropped.
	markBizFailed := func(r *entity.AsyncTaskRef, msg string) {
		if bizErr := s.updateBizFailed(ctx, r, msg); bizErr != nil {
			g.Log().Warningf(ctx, "[AsyncTask.Sync] mark business row failed: task_id=%s table_name=%v biz_id=%v err=%v",
				r.TaskID, r.TableName, r.BizID, bizErr)
		}
	}
	// saveRef persists data on the binding row. It is best-effort, but a
	// failing write is logged instead of being silently dropped.
	saveRef := func(taskID string, data gdb.Map) {
		if _, refErr := dao.AsyncTaskRef.UpdateByTaskID(ctx, taskID, data); refErr != nil {
			g.Log().Warningf(ctx, "[AsyncTask.Sync] update binding row: task_id=%s err=%v", taskID, refErr)
		}
	}

	seen := make(map[string]struct{}, len(items))
	handled := 0
	for _, it := range items {
		r := refMap[it.TaskID]
		if r == nil {
			// The middleware returned a task that was not asked for.
			continue
		}
		seen[it.TaskID] = struct{}{}

		// errMsg stays empty unless a finished task could not be settled.
		errMsg := ""
		switch it.State {
		case 0, 1, 3:
			// Still in progress from the business point of view: the binding
			// row is deliberately left untouched to save a write.
		case 2, 4:
			// Finished on the middleware side: store the OSS file on the
			// business row, or mark that row as failed when this is impossible.
			if it.OssFile == "" {
				errMsg = "中间件返回空oss地址"
			} else if bizErr := s.updateBizSuccess(ctx, r, it.OssFile); bizErr != nil {
				errMsg = fmt.Sprintf("生成音频失败: %v", bizErr)
			}
			if errMsg != "" {
				markBizFailed(r, errMsg)
			} else {
				handled++
			}
			// The binding row receives the middleware state in both outcomes.
			saveRef(it.TaskID, gdb.Map{
				entity.AsyncTaskRefCol.State:    it.State,
				entity.AsyncTaskRefCol.OssFile:  it.OssFile,
				entity.AsyncTaskRefCol.ErrorMsg: errMsg,
			})
		default:
			// Any other state is reported as-is and not acted upon.
		}

		out.List = append(out.List, dto.SyncAsyncTasksItem{
			TaskID:    it.TaskID,
			State:     it.State,
			TableName: r.TableName,
			BizID:     fmt.Sprintf("%d", r.BizID),
			OssFile:   it.OssFile,
			ErrorMsg:  errMsg,
		})
	}

	// Task IDs missing from the response. Per the original notes the
	// middleware hard-deletes tasks whose retries are exhausted, so the batch
	// endpoint no longer returns them. They are treated as terminally failed,
	// and the binding row is soft-deleted to keep it from being polled again.
	const missingTaskState = 3
	for _, taskID := range taskIDs {
		if _, ok := seen[taskID]; ok {
			continue
		}
		r := refMap[taskID]
		if r == nil {
			continue
		}
		msg := "模型任务不存在已失败"
		markBizFailed(r, msg)
		saveRef(taskID, gdb.Map{
			entity.AsyncTaskRefCol.State:    missingTaskState,
			entity.AsyncTaskRefCol.ErrorMsg: msg,
			"deleted_at":                    gtime.Now(),
		})
		out.List = append(out.List, dto.SyncAsyncTasksItem{
			TaskID:    taskID,
			State:     missingTaskState,
			TableName: r.TableName,
			BizID:     fmt.Sprintf("%d", r.BizID),
			OssFile:   "",
			ErrorMsg:  msg,
		})
	}

	out.Handled = handled
	g.Log().Infof(ctx, "[AsyncTask.Sync] total=%d handled=%d", out.Total, out.Handled)
	return out, nil
}
// updateBizSuccess marks the business row referenced by ref as successful and
// stores ossFile on it.
//
// The target table is chosen by ref.TableName: audio rows receive
// consts.AudioStatusSuccess, custom-voice rows receive the status value 1
// (presumably the custom-voice success status; no named constant is visible
// here). Any other table name yields an error. The error of the underlying
// UpdateStatus call is returned unchanged.
func (s *asyncTaskService) updateBizSuccess(ctx context.Context, ref *entity.AsyncTaskRef, ossFile string) error {
	var err error
	switch table := ref.TableName; {
	case table == public.TableNameAudio:
		_, err = dao.Audio.UpdateStatus(ctx, ref.BizID, consts.AudioStatusSuccess, "", ossFile, 0, "")
	case table == public.TableNameCustomVoice:
		_, err = dao.CustomVoice.UpdateStatus(ctx, ref.BizID, 1, "", ossFile)
	default:
		err = fmt.Errorf("未知 table_name=%s", table)
	}
	return err
}
// updateBizFailed marks the business row referenced by ref as failed and
// records msg as the failure reason.
//
// The target table is chosen by ref.TableName: audio rows receive
// consts.AudioStatusFailed, custom-voice rows receive the status value 2
// (presumably the custom-voice failure status; no named constant is visible
// here). An unrecognised table name is ignored and nil is returned, because
// there is no business row to update in that case.
func (s *asyncTaskService) updateBizFailed(ctx context.Context, ref *entity.AsyncTaskRef, msg string) error {
	if ref.TableName == public.TableNameAudio {
		_, audioErr := dao.Audio.UpdateStatus(ctx, ref.BizID, consts.AudioStatusFailed, msg, "", 0, "")
		return audioErr
	}
	if ref.TableName == public.TableNameCustomVoice {
		_, voiceErr := dao.CustomVoice.UpdateStatus(ctx, ref.BizID, 2, msg, "")
		return voiceErr
	}
	// Unknown table: nothing to mark as failed.
	return nil
}