refactor(task): 重构任务服务和数据结构

This commit is contained in:
2026-06-12 15:29:05 +08:00
parent 1c6c9bae14
commit b3b111995e
7 changed files with 227 additions and 292 deletions

View File

@@ -27,10 +27,7 @@ type taskService struct{}
// Create 创建任务
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
var (
startAt = time.Now()
taskID = uuid.NewString()
)
taskID := uuid.NewString()
// 1) 检查模型配置,并且获取模型
userInfo, err := utils.GetUserInfo(ctx)
@@ -64,10 +61,6 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
}
// 3) 插入任务记录
if model.CallMode != nil && *model.CallMode == public.CallModeAsync {
// 异步调用:注入回调地址后提交,拿到 task_id 轮询
req.RequestPayload = util.InjectCallbackURL(ctx, req.RequestPayload, model.CallbackUrl)
}
requestPayload := entity.RequestPayload{
Body: req.RequestPayload,
Headers: util.ParseHeadMsgHeaders(model.HeadMsg),
@@ -107,8 +100,7 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
TaskID: taskID,
OpType: "createTask",
Success: 1,
ErrorMsg: "",
CostMs: time.Since(startAt).Milliseconds(),
CostMs: time.Since(time.Now()).Milliseconds(),
RequestPayload: &requestPayload,
ResponsePayload: gdb.Map{
"taskId": taskID,

View File

@@ -35,14 +35,17 @@ type asyncWorker struct {
// handleOne 执行一次完整的任务
func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq) {
var (
body = task.RequestPayload.Body // 核心请求参数
maxRetry = model.RetryTimes // 重试次数
startTime = time.Now()
modelMessages = map[string]any{}
body = task.RequestPayload.Body
maxRetry = model.RetryTimes
startTime = time.Now()
result map[string]any
err error
)
g.Log().Infof(ctx, "[执行任务][开始] taskId=%s model=%s", task.TaskID, task.ModelName)
// ============================================
// 1) 分布式并发控制
// ============================================
semKey := fmt.Sprintf("asynch:sem:%s", task.ModelName)
maxC := queue.GetRuntimeMaxConcurrency(ctx, task.ModelName, model.MaxConcurrency)
acquired, err := queue.AcquireSemaphore(ctx, semKey, maxC, 3600)
@@ -53,101 +56,91 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa
}
if !acquired {
_, _ = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
SQLBaseDO: beans.SQLBaseDO{
Id: task.Id,
},
State: public.TaskStatusPending,
SQLBaseDO: beans.SQLBaseDO{Id: task.Id},
State: public.TaskStatusPending,
})
g.Log().Infof(ctx, "[执行任务][排队] 并发已满,放回队列 taskId=%s", task.TaskID)
return
}
defer func() { _ = queue.ReleaseSemaphore(ctx, semKey) }()
// ============================================
// 2) 调用模型
// ============================================
switch {
case model.CallMode != nil && *model.CallMode == public.CallModeStream:
rawBytes, err := w.callModelStream(ctx, task, model, body)
if err != nil {
w.failTask(ctx, task, startTime, err.Error())
return
}
modelMessages, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
if err != nil {
w.failTask(ctx, task, startTime, err.Error())
rawBytes, streamErr := w.callModelStream(ctx, task, model, body)
if streamErr != nil {
w.failTask(ctx, task, startTime, streamErr.Error())
return
}
result, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
case model.CallMode != nil && *model.CallMode == public.CallModeAsync:
modelMessages, err = w.callModel(ctx, task, model, body)
if err != nil {
w.failTask(ctx, task, startTime, err.Error())
return
}
modelMessages, err = util.PullTaskResult(ctx, modelMessages, model.QueryConfig, model.HeadMsg)
if err != nil {
w.failTask(ctx, task, startTime, err.Error())
return
result, err = w.callModel(ctx, task, model, body)
if err == nil {
result, err = util.PullTaskResult(ctx, result, model.QueryConfig, model.HeadMsg)
}
default:
modelMessages, err = w.callModel(ctx, task, model, body)
if err != nil {
w.failTask(ctx, task, startTime, err.Error())
return
}
result, err = w.callModel(ctx, task, model, body)
}
// 3) 保存临时文件
tmpPath, err := util.SaveTempFileByType(task.TaskID, modelMessages, task.TmpFile)
if err == nil && tmpPath != "" {
task.TmpFile = tmpPath
task.Phase = 1
_, err = dao.ModelGatewayTask.Update(ctx, task)
if err != nil {
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err)
}
}
// 4) 解析校验 + 响应映射(可重试,失败重新调模型)
modelMessages, err = w.parseAndRetry(ctx, modelMessages, task, model, req, maxRetry, startTime)
if err != nil {
task.TextResult = modelMessages
w.failTask(ctx, task, startTime, err.Error())
return
}
// ============================================
// 3) 缓存临时文件
// ============================================
if tmpPath, tmpErr := util.SaveTempFileByType(task.TaskID, result, task.TmpFile); tmpErr == nil && tmpPath != "" {
task.TmpFile = tmpPath
task.Phase = 1
_, _ = dao.ModelGatewayTask.Update(ctx, task)
}
// ============================================
// 4) 解析校验 + 响应映射(可重试)
// ============================================
result, err = w.parseAndRetry(ctx, result, task, model, req, maxRetry, startTime)
if err != nil {
task.TextResult = result
w.failTask(ctx, task, startTime, err.Error())
return
}
// ============================================
// 5) 上传 OSS可重试
// ============================================
var oss *gateway.UploadFileResponse
for attempt := 0; attempt <= maxRetry; attempt++ {
if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] OSS上传 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
}
oss, err = w.uploadOSS(ctx, task)
oss, err = gateway.UploadByTask(ctx, gjson.New(result).MustToJson(), "json")
if err == nil {
break
}
g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v",
task.TaskID, attempt, maxRetry, err)
g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
if attempt == maxRetry {
task.State = 3
task.State = public.TaskStatusFailed
task.ErrorMsg = err.Error()
task.Phase = 1
_, err = dao.ModelGatewayTask.Update(ctx, task)
if err != nil {
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err)
}
_, _ = dao.ModelGatewayTask.Update(ctx, task)
w.failTask(ctx, task, startTime, fmt.Sprintf("OSS上传重试耗尽: %v", err))
return
}
}
// 6) 成功回调
task.State = 2
// ============================================
// 6) 成功收尾
// ============================================
task.State = public.TaskStatusSuccess
task.DurationSeconds = int64(time.Since(startTime).Seconds())
task.ResultFile = &entity.ResultFile{
OssFile: oss.FileAddressPrefix + oss.FileURL,
FileType: oss.FileFormat,
FileSize: int64(oss.FileSize),
}
task.TextResult = modelMessages
task.TextResult = result
if _, err = dao.ModelGatewayTask.Update(ctx, task); err != nil {
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err)
return
@@ -159,10 +152,9 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa
go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), task, req.EpicycleId)
}
g.Log().Infof(ctx, "[执行任务][成功] taskId=%s duration=%ds fileType=%s textLen=%d callbackUrl=%s",
task.TaskID, task.DurationSeconds, oss.FileFormat, len(body), task.CallbackURL)
g.Log().Infof(ctx, "[执行任务][成功] taskId=%s duration=%ds fileType=%s",
task.TaskID, task.DurationSeconds, oss.FileFormat)
// 7) 删除临时文件
_ = os.Remove(task.TmpFile)
}
@@ -495,16 +487,6 @@ func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[
// return mappedResponse, nil
// }
// uploadOSS 从临时文件上传 OSS
func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.ModelGatewayTask) (*gateway.UploadFileResponse, error) {
data, err := os.ReadFile(t.TmpFile)
if err != nil {
return nil, fmt.Errorf("读取临时文件失败: %w", err)
}
_, ext := util.DetectFileType(data)
return gateway.UploadByTask(ctx, data, ext)
}
// failTask 任务失败统一处理:更新数据库 + 释放排队 + 回调
func (w *asyncWorker) failTask(ctx context.Context, t *entity.ModelGatewayTask, startTime time.Time, errMsg string) {
t.State = 3