refactor(files): 优化文件处理和任务服务逻辑
This commit is contained in:
@@ -19,7 +19,6 @@ import (
|
||||
"model-gateway/model/dto"
|
||||
"model-gateway/model/entity"
|
||||
"model-gateway/service/gateway"
|
||||
"model-gateway/service/queue"
|
||||
|
||||
"gitea.redpowerfuture.com/red-future/common/beans"
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
@@ -38,50 +37,28 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa
|
||||
body = task.RequestPayload.Body
|
||||
maxRetry = model.RetryTimes
|
||||
startTime = time.Now()
|
||||
rawBytes []byte
|
||||
result map[string]any
|
||||
err error
|
||||
)
|
||||
g.Log().Infof(ctx, "[执行任务][开始] taskId=%s model=%s", task.TaskID, task.ModelName)
|
||||
|
||||
g.Log().Infof(ctx, "[handleOne] 开始 taskId=%s model=%s", task.TaskID, task.ModelName)
|
||||
|
||||
// ============================================
|
||||
// 1) 分布式并发控制
|
||||
// ============================================
|
||||
semKey := fmt.Sprintf("asynch:sem:%s", task.ModelName)
|
||||
maxC := queue.GetRuntimeMaxConcurrency(ctx, task.ModelName, model.MaxConcurrency)
|
||||
acquired, err := queue.AcquireSemaphore(ctx, semKey, maxC, 3600)
|
||||
if err != nil {
|
||||
task.DurationSeconds = int64(time.Since(startTime).Seconds())
|
||||
w.failTask(ctx, task, startTime, err.Error())
|
||||
return
|
||||
}
|
||||
if !acquired {
|
||||
_, _ = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: task.Id},
|
||||
State: public.TaskStatusPending,
|
||||
})
|
||||
g.Log().Infof(ctx, "[执行任务][排队] 并发已满,放回队列 taskId=%s", task.TaskID)
|
||||
return
|
||||
}
|
||||
defer func() { _ = queue.ReleaseSemaphore(ctx, semKey) }()
|
||||
|
||||
// ============================================
|
||||
// 2) 调用模型
|
||||
// 1) 调用模型
|
||||
// ============================================
|
||||
for attempt := 0; ; attempt++ {
|
||||
if attempt > 0 {
|
||||
g.Log().Infof(ctx, "[执行任务][重试] 调用模型 第%d次 taskId=%s", attempt, task.TaskID)
|
||||
g.Log().Infof(ctx, "[handleOne] 调模型重试 第%d次 taskId=%s", attempt, task.TaskID)
|
||||
time.Sleep(time.Duration(attempt) * time.Second)
|
||||
}
|
||||
|
||||
switch {
|
||||
case model.CallMode != nil && *model.CallMode == public.CallModeStream:
|
||||
rawBytes, streamErr := w.callModelStream(ctx, task, model, body)
|
||||
if streamErr != nil {
|
||||
err = streamErr
|
||||
g.Log().Warningf(ctx, "[执行任务][调用失败] taskId=%s attempt=%d err=%v", task.TaskID, attempt, err)
|
||||
continue
|
||||
rawBytes, err = w.callModelStream(ctx, task, model, body)
|
||||
if err == nil {
|
||||
result, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
|
||||
}
|
||||
result, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
|
||||
case model.CallMode != nil && *model.CallMode == public.CallModeAsync:
|
||||
result, err = w.callModel(ctx, task, model, body)
|
||||
if err == nil {
|
||||
@@ -95,24 +72,17 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa
|
||||
break
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "Timeout") {
|
||||
if !strings.Contains(err.Error(), "Timeout") &&
|
||||
!strings.Contains(err.Error(), "InternalServiceError") {
|
||||
w.failTask(ctx, task, startTime, err.Error())
|
||||
return
|
||||
}
|
||||
g.Log().Warningf(ctx, "[执行任务][调用失败] taskId=%s attempt=%d err=%v", task.TaskID, attempt, err)
|
||||
|
||||
g.Log().Warningf(ctx, "[handleOne] 调模型失败 taskId=%s attempt=%d err=%v", task.TaskID, attempt, err)
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 3) 缓存临时文件
|
||||
// ============================================
|
||||
if tmpPath, tmpErr := util.SaveTempFileByType(task.TaskID, result, task.TmpFile); tmpErr == nil && tmpPath != "" {
|
||||
task.TmpFile = tmpPath
|
||||
task.Phase = 1
|
||||
_, _ = dao.ModelGatewayTask.Update(ctx, task)
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 4) 解析校验 + 响应映射(可重试)
|
||||
// 2) 解析校验 + 响应映射(可重试)
|
||||
// ============================================
|
||||
result, err = w.parseAndRetry(ctx, result, task, model, req, maxRetry, startTime)
|
||||
if err != nil {
|
||||
@@ -122,34 +92,26 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 5) 上传 OSS(可重试)
|
||||
// 3) 上传 OSS(可重试)
|
||||
// ============================================
|
||||
var oss *gateway.UploadFileResponse
|
||||
for attempt := 0; attempt <= maxRetry; attempt++ {
|
||||
if attempt > 0 {
|
||||
g.Log().Infof(ctx, "[执行任务][重试] OSS上传 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
|
||||
g.Log().Infof(ctx, "[handleOne] OSS上传重试 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
|
||||
}
|
||||
startUpload := time.Now()
|
||||
oss, err = gateway.UploadByTask(ctx, gjson.New(result).MustToJson(), "json")
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
cost := time.Since(startUpload)
|
||||
g.Log().Infof(ctx, "本次上传耗时:%s", cost)
|
||||
|
||||
g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
|
||||
g.Log().Errorf(ctx, "[handleOne] OSS上传失败 taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
|
||||
if attempt == maxRetry {
|
||||
task.State = public.TaskStatusFailed
|
||||
task.ErrorMsg = err.Error()
|
||||
task.Phase = 1
|
||||
_, _ = dao.ModelGatewayTask.Update(ctx, task)
|
||||
w.failTask(ctx, task, startTime, fmt.Sprintf("OSS上传重试耗尽: %v", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 6) 成功收尾
|
||||
// 4) 成功收尾
|
||||
// ============================================
|
||||
task.State = public.TaskStatusSuccess
|
||||
task.DurationSeconds = int64(time.Since(startTime).Seconds())
|
||||
@@ -159,21 +121,19 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTa
|
||||
FileSize: int64(oss.FileSize),
|
||||
}
|
||||
task.TextResult = result
|
||||
|
||||
if _, err = dao.ModelGatewayTask.Update(ctx, task); err != nil {
|
||||
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err)
|
||||
g.Log().Errorf(ctx, "[handleOne] 更新DB失败 taskId=%s err=%v", task.TaskID, err)
|
||||
return
|
||||
}
|
||||
|
||||
queue.ReleaseQueueSlot(ctx, task.ModelName, task.TaskID)
|
||||
go gateway.TriggerCallback(util.AsyncCtx(ctx), task)
|
||||
if req.EpicycleId != 0 {
|
||||
go gateway.TriggerPromptsCallback(util.AsyncCtx(ctx), task, req.EpicycleId)
|
||||
}
|
||||
|
||||
g.Log().Infof(ctx, "[执行任务][成功] taskId=%s duration=%ds fileType=%s",
|
||||
g.Log().Infof(ctx, "[handleOne] 成功 taskId=%s duration=%ds fileType=%s",
|
||||
task.TaskID, task.DurationSeconds, oss.FileFormat)
|
||||
|
||||
_ = os.Remove(task.TmpFile)
|
||||
}
|
||||
|
||||
// callModelStream 调用模型,返回原始字节(不做响应映射,用于流式输出)
|
||||
@@ -313,12 +273,10 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.ModelGatewayTa
|
||||
// parseAndRetry 解析模型返回结果,并重试
|
||||
func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) {
|
||||
var lastErr error
|
||||
|
||||
for attempt := 0; attempt <= maxRetry; attempt++ {
|
||||
if attempt > 0 {
|
||||
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
|
||||
}
|
||||
|
||||
// 1) 响应映射
|
||||
mapped, err := util.MapResponsePayload(model.ResponseMapping, body)
|
||||
if err != nil {
|
||||
@@ -372,7 +330,7 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
|
||||
}
|
||||
|
||||
var rawResp map[string]any
|
||||
if err := json.Unmarshal(rawData, &rawResp); err != nil {
|
||||
if err = json.Unmarshal(rawData, &rawResp); err != nil {
|
||||
g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err)
|
||||
continue
|
||||
}
|
||||
@@ -553,15 +511,11 @@ func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[
|
||||
// return mappedResponse, nil
|
||||
// }
|
||||
|
||||
// failTask 任务失败统一处理:更新数据库 + 释放排队 + 回调
|
||||
// failTask 任务失败统一处理
|
||||
func (w *asyncWorker) failTask(ctx context.Context, t *entity.ModelGatewayTask, startTime time.Time, errMsg string) {
|
||||
t.State = 3
|
||||
t.ErrorMsg = errMsg
|
||||
t.DurationSeconds = int64(time.Since(startTime).Seconds())
|
||||
_, err := dao.ModelGatewayTask.Update(ctx, t)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[执行任务][更新数据库失败] taskId=%s err=%v", t.TaskID, err)
|
||||
}
|
||||
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
go gateway.TriggerCallback(util.AsyncCtx(ctx), t)
|
||||
_, _ = dao.ModelGatewayTask.Update(ctx, t) // 更新任务状态
|
||||
go gateway.TriggerCallback(util.AsyncCtx(ctx), t) // 触发回调
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user