refactor(files): 优化文件处理和任务服务逻辑
This commit is contained in:
@@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"model-gateway/common/util"
|
||||
"model-gateway/consts/public"
|
||||
"model-gateway/service/queue"
|
||||
"time"
|
||||
|
||||
"model-gateway/dao"
|
||||
@@ -28,12 +27,15 @@ type taskService struct{}
|
||||
// Create 创建任务
|
||||
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
|
||||
taskID := uuid.NewString()
|
||||
startAt := time.Now()
|
||||
|
||||
// 1) 检查模型配置,并且获取模型
|
||||
// 1) 获取用户信息
|
||||
userInfo, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2) 检查模型配置
|
||||
model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{
|
||||
TenantId: userInfo.TenantId,
|
||||
@@ -48,86 +50,63 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
|
||||
return nil, errors.New("模型不存在或未启用")
|
||||
}
|
||||
|
||||
// 2) 排队上限(严格控制:Redis 原子闸门)
|
||||
limit := queue.GetRuntimeQueueLimit(ctx, req.ModelName, model.MaxConcurrency*2)
|
||||
if limit > 0 {
|
||||
ok, err := queue.AcquireQueueSlot(ctx, req.ModelName, taskID, limit, model.TimeoutSeconds)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !ok {
|
||||
return nil, errors.New("任务排队已满,请稍后再试")
|
||||
}
|
||||
// TODO: 排队控制暂时关闭,后续需要时取消注释
|
||||
// limit := queue.GetRuntimeQueueLimit(ctx, req.ModelName, model.MaxConcurrency*2)
|
||||
// if limit > 0 {
|
||||
// ok, err := queue.AcquireQueueSlot(ctx, req.ModelName, taskID, limit, model.TimeoutSeconds)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// if !ok {
|
||||
// return nil, errors.New("任务排队已满,请稍后再试")
|
||||
// }
|
||||
// }
|
||||
|
||||
// 3) 构建任务实体
|
||||
task := &entity.ModelGatewayTask{
|
||||
ModelName: model.ModelName,
|
||||
TaskID: taskID,
|
||||
State: public.TaskStatusRunning,
|
||||
BizName: req.BizName,
|
||||
CallbackURL: req.CallbackUrl,
|
||||
RequestPayload: &entity.RequestPayload{
|
||||
Body: req.RequestPayload,
|
||||
Headers: util.ParseHeadMsgHeaders(model.HeadMsg),
|
||||
},
|
||||
EpicycleId: req.EpicycleId,
|
||||
}
|
||||
|
||||
// 3) 插入任务记录
|
||||
requestPayload := entity.RequestPayload{
|
||||
Body: req.RequestPayload,
|
||||
Headers: util.ParseHeadMsgHeaders(model.HeadMsg),
|
||||
}
|
||||
task := new(entity.ModelGatewayTask)
|
||||
task.ModelName = model.ModelName
|
||||
task.TaskID = taskID
|
||||
task.State = public.TaskStatusRunning
|
||||
task.BizName = req.BizName
|
||||
task.CallbackURL = req.CallbackUrl
|
||||
task.RequestPayload = &requestPayload
|
||||
task.EpicycleId = req.EpicycleId
|
||||
// 4) 插入任务记录
|
||||
id, err := dao.ModelGatewayTask.Insert(ctx, task)
|
||||
if err != nil { // 入库失败:回滚闸门占位
|
||||
queue.ReleaseQueueSlot(ctx, req.ModelName, taskID)
|
||||
if err != nil {
|
||||
// TODO: 恢复排队逻辑后,此处需要回滚排队占位
|
||||
// queue.ReleaseQueueSlot(ctx, req.ModelName, taskID)
|
||||
return nil, err
|
||||
}
|
||||
task.Id = id
|
||||
// 4) 写操作日志(不影响主流程,失败忽略)
|
||||
ip := ""
|
||||
ua := ""
|
||||
apiPath := "/task/createTask"
|
||||
httpMethod := "POST"
|
||||
|
||||
// 5) 记录操作日志(非关键路径,失败不影响主流程)
|
||||
ip, ua := "", ""
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
ip = utils.GetLocalIP()
|
||||
ua = r.UserAgent()
|
||||
apiPath = r.URL.Path
|
||||
httpMethod = r.Method
|
||||
}
|
||||
_, _ = dao.ModelGatewayLogsOp.Insert(ctx, &entity.ModelGatewayLogsOp{
|
||||
IP: ip,
|
||||
UserAgent: ua,
|
||||
APIPath: apiPath,
|
||||
HttpMethod: httpMethod,
|
||||
BizName: req.BizName,
|
||||
ModelName: req.ModelName,
|
||||
TaskID: taskID,
|
||||
OpType: "createTask",
|
||||
Success: 1,
|
||||
CostMs: time.Since(time.Now()).Milliseconds(),
|
||||
RequestPayload: &requestPayload,
|
||||
ResponsePayload: gdb.Map{
|
||||
"taskId": taskID,
|
||||
},
|
||||
IP: ip,
|
||||
UserAgent: ua,
|
||||
APIPath: "/task/createTask",
|
||||
HttpMethod: "POST",
|
||||
BizName: req.BizName,
|
||||
ModelName: req.ModelName,
|
||||
TaskID: taskID,
|
||||
OpType: "createTask",
|
||||
Success: 1,
|
||||
CostMs: time.Since(startAt).Milliseconds(),
|
||||
RequestPayload: task.RequestPayload,
|
||||
ResponsePayload: gdb.Map{"taskId": taskID},
|
||||
})
|
||||
|
||||
//// 5) 抢占任务:改为执行中
|
||||
//rows, err := dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
|
||||
// SQLBaseDO: beans.SQLBaseDO{Id: id},
|
||||
// State: public.TaskStatusRunning,
|
||||
//})
|
||||
//if err != nil {
|
||||
// return nil, err
|
||||
//}
|
||||
//if rows == 0 {
|
||||
// return nil, fmt.Errorf("任务不存在: id=%d", id)
|
||||
//}
|
||||
|
||||
// 6) 查询任务信息
|
||||
//task, err := dao.ModelGatewayTask.Get(ctx, &entity.ModelGatewayTask{
|
||||
// SQLBaseDO: beans.SQLBaseDO{Id: id},
|
||||
//})
|
||||
//if err != nil {
|
||||
// return nil, err
|
||||
//}
|
||||
|
||||
// 7) 创建成功后立即异步尝试执行当前任务
|
||||
// 6) 异步执行任务
|
||||
go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req)
|
||||
|
||||
return &dto.CreateTaskRes{TaskID: taskID}, nil
|
||||
|
||||
Reference in New Issue
Block a user