refactor(service): 重构模型网关服务结构

This commit is contained in:
2026-06-11 17:58:49 +08:00
parent afd60caf56
commit 1c6c9bae14
34 changed files with 784 additions and 1223 deletions

View File

@@ -77,14 +77,14 @@ type CallbackPayload struct {
}
// TriggerCallback 任务的回调
func TriggerCallback(ctx context.Context, t *entity.AsynchTask) {
func TriggerCallback(ctx context.Context, t *entity.ModelGatewayTask) {
headers := util.ForwardHeaders(ctx)
var resp struct{}
payload := CallbackPayload{
TaskId: t.TaskID,
State: t.State,
OssFile: t.OssFile,
FileType: t.FileType,
OssFile: t.ResultFile.OssFile,
FileType: t.ResultFile.FileType,
Messages: t.TextResult,
ErrorMsg: t.ErrorMsg,
}
@@ -111,7 +111,7 @@ type PromptsCallbackPayload struct {
}
// TriggerPromptsCallback 任务成功后的提示词回调
func TriggerPromptsCallback(ctx context.Context, t *entity.AsynchTask, epicycleId int64) {
func TriggerPromptsCallback(ctx context.Context, t *entity.ModelGatewayTask, epicycleId int64) {
callbackURL := "prompts-core/session/callback"
headers := util.ForwardHeaders(ctx)
var resp struct{}

View File

@@ -1,102 +0,0 @@
package job
import (
"context"
"model-gateway/model/dto"
"model-gateway/service/queue"
"os"
"time"
"model-gateway/dao"
"github.com/gogf/gf/v2/frame/g"
)
var Cleaner = &cleaner{}
type cleaner struct{}
// RunOnce 由上层定时任务触发:执行一次清理/重试
func (c *cleaner) RunOnce(ctx context.Context) (res *dto.CleanWorkRes, err error) {
// 1) 清理已下载(state=4)且过期的任务(硬删除 + OSS
expired, err := dao.Task.ListExpiredDownloadedGlobal(ctx, 200)
if err != nil {
g.Log().Errorf(ctx, "[清理] 查询已下载过期任务失败: %v", err)
} else {
for _, t := range expired {
_ = os.Remove(t.TmpFile)
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
}
g.Log().Infof(ctx, "[清理] 已下载过期任务清理完成, count=%d", len(expired))
}
// 2) 超时任务标失败
list, err := dao.Task.ListTimeoutTasksGlobal(ctx, 200)
if err != nil {
g.Log().Errorf(ctx, "[清理] 查询超时任务失败: %v", err)
} else {
for _, t := range list {
t.ErrorMsg = "任务超时自动失败"
_, err = dao.Task.Update(ctx, t)
if err != nil {
g.Log().Errorf(ctx, "[清理] 标记任务失败: %v", err)
}
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
}
g.Log().Infof(ctx, "[清理] 超时任务处理完成, count=%d", len(list))
}
// 3) 失败(state=3)的任务按模型配置 retry_times 重新入队(放到队尾)
retryable, err := dao.Task.ListFailedRetryableGlobal(ctx, 200)
if err != nil {
g.Log().Errorf(ctx, "[清理] 查询可重试任务失败: %v", err)
} else {
for _, t := range retryable {
// 失败任务重新入队state=3 -> 0先严格占用 queue_limit slot占用失败则留在失败态下一轮再尝试
// 获取模型配置以得到 queue_limit / expected_seconds
m, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
if err != nil || m == nil {
continue
}
limit := queue.GetRuntimeQueueLimit(ctx, t.ModelName, m.MaxConcurrency*2)
if limit > 0 {
ok, _ := queue.AcquireQueueSlot(ctx, t.ModelName, t.TaskID, limit, m.TimeoutSeconds)
if !ok {
continue
}
}
// retry_queue_max_seconds 控制失败重试的排队策略:
// - =0失败重试插队到队首
// - >0当任务从创建到现在的排队时长 >= maxSeconds则插队到队首否则仍放到队尾
now := time.Now()
enqueueAt := now
maxSeconds := t.RetryQueueMaxSeconds
if maxSeconds == 0 {
enqueueAt = now.Add(-100 * 365 * 24 * time.Hour)
} else if maxSeconds > 0 && t.CreatedAt != nil {
if now.Sub(t.CreatedAt.Time) >= time.Duration(maxSeconds)*time.Second {
enqueueAt = now.Add(-100 * 365 * 24 * time.Hour)
}
}
_ = dao.Task.RequeueForRetryGlobal(ctx, t.Id, enqueueAt)
}
g.Log().Infof(ctx, "[清理] 可重试任务重新入队完成, count=%d", len(retryable))
}
// 4) 超过重试次数仍失败(state=3)的任务:硬删除
exhausted, err := dao.Task.ListFailedExhaustedGlobal(ctx, 200)
if err != nil {
g.Log().Errorf(ctx, "[清理] 查询重试耗尽任务失败: %v", err)
} else {
for _, t := range exhausted {
_ = os.Remove(t.TmpFile)
// 重试耗尽硬删除:释放闸门占位(兜底,若此前已释放则幂等)
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
}
g.Log().Infof(ctx, "[清理] 重试耗尽任务清理完成, count=%d", len(exhausted))
}
return &dto.CleanWorkRes{
Ok: true,
}, nil
}

View File

@@ -18,7 +18,7 @@ import (
"github.com/gogf/gf/v2/util/gconv"
)
var Model = &modelService{}
var ModelGatewayModels = &modelService{}
type modelService struct{}
@@ -37,7 +37,7 @@ func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (*dt
}
// 3入库
id, err := dao.Model.Insert(ctx, util.ConvertTo[entity.AsynchModel](req))
id, err := dao.ModelGatewayModels.Insert(ctx, util.ConvertTo[entity.ModelGatewayModel](req))
if err != nil {
return nil, err
}
@@ -56,27 +56,27 @@ func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) erro
req.IsOwner = gconv.PtrInt(1)
if isAdmin, _ := gateway.IsSuperAdmin(ctx); isAdmin {
req.IsOwner = gconv.PtrInt(0)
_, err := dao.Model.Update(ctx, util.ConvertTo[entity.AsynchModel](req))
_, err := dao.ModelGatewayModels.Update(ctx, util.ConvertTo[entity.ModelGatewayModel](req))
return err
}
// 3跨租户判断超管的模型不允许直接修改走插入新记录
model, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
model, err := dao.ModelGatewayModels.GetByAcrossTenant(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
})
if err != nil {
return err
}
if model.TenantId == 1 {
_, err = dao.Model.Insert(ctx, util.ConvertTo[entity.AsynchModel](req))
_, err = dao.ModelGatewayModels.Insert(ctx, util.ConvertTo[entity.ModelGatewayModel](req))
return err
}
_, err = dao.Model.Update(ctx, util.ConvertTo[entity.AsynchModel](req))
_, err = dao.ModelGatewayModels.Update(ctx, util.ConvertTo[entity.ModelGatewayModel](req))
return err
}
// Delete 删除模型
func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) error {
_, err := dao.Model.Delete(ctx, &entity.AsynchModel{
_, err := dao.ModelGatewayModels.Delete(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
})
return err
@@ -91,7 +91,7 @@ func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetM
if g.IsEmpty(req.ID) {
req.Creator = user.UserName
}
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{
Id: req.ID,
Creator: user.UserName,
@@ -123,7 +123,7 @@ func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (*dto.Li
req.Creator = user.UserName
// 3查询
models, total, err := dao.Model.GetByCreatorAndPlatform(ctx, req)
models, total, err := dao.ModelGatewayModels.GetByCreatorAndPlatform(ctx, req)
if err != nil {
return nil, err
}
@@ -134,7 +134,7 @@ func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (*dto.Li
// UpdateChatModel 设置会话模型
func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error {
// 1校验新模型存在
newModel, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
newModel, err := dao.ModelGatewayModels.GetByAcrossTenant(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
})
if err != nil || newModel == nil {
@@ -146,7 +146,7 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM
if err != nil {
return err
}
currentModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
currentModel, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
IsChatModel: gconv.PtrInt(1),
})
@@ -161,7 +161,7 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM
return errors.New("当前模型为非推理模型,不能设置为会话模型")
}
if currentModel.Id != req.Id {
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
_, err = dao.ModelGatewayModels.Update(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id},
IsChatModel: gconv.PtrInt(0),
})
@@ -171,7 +171,7 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM
}
}
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
_, err = dao.ModelGatewayModels.Update(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
IsChatModel: gconv.PtrInt(1),
})
@@ -185,7 +185,7 @@ func (s *modelService) GetIsChatModel(ctx context.Context) (*dto.GetIsChatModelR
if err != nil {
return nil, err
}
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
IsChatModel: gconv.PtrInt(1),
})
@@ -203,14 +203,14 @@ func (s *modelService) clearUserChatModel(ctx context.Context) error {
if err != nil {
return err
}
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
IsChatModel: gconv.PtrInt(1),
})
if err != nil || model == nil {
return nil
}
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
_, err = dao.ModelGatewayModels.Update(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Id: model.Id},
IsChatModel: gconv.PtrInt(0),
})
@@ -223,7 +223,7 @@ func (s *modelService) checkChatModelUnique(ctx context.Context) error {
if err != nil {
return err
}
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
IsChatModel: gconv.PtrInt(1),
})

View File

@@ -43,14 +43,14 @@ func AutoTune(ctx context.Context, req *dto.AutoTuneReq) (res *dto.AutoTuneRes,
req.WindowSeconds = 3600 // 默认1小时
}
// 1) 读取模型配置cap按 model_name 聚合去重(如果表里有多租户重复数据,取较大上限)
var modelRows []*entity.AsynchModel
var modelRows []*entity.ModelGatewayModel
if err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
Where("deleted_at IS NULL").
Where(entity.AsynchModelCol.Enabled, 1).
Where(entity.ModelGatewayModelCol.Enabled, 1).
Scan(&modelRows); err != nil {
return nil, err
}
modelMap := make(map[string]*entity.AsynchModel)
modelMap := make(map[string]*entity.ModelGatewayModel)
for _, m := range modelRows {
if m == nil || m.ModelName == "" {
continue

View File

@@ -2,36 +2,31 @@ package stat
import (
"context"
"model-gateway/model/entity"
"model-gateway/dao"
"model-gateway/model/dto"
)
type statService struct{}
var ModelGatewayLogsStat = &logsStatService{}
var Stat = &statService{}
type logsStatService struct{}
func (s *statService) List(ctx context.Context, req *dto.ListModelStatReq) (res *dto.ListModelStatRes, err error) {
pageNum, pageSize := 1, 10
if req != nil {
if req.PageNum > 0 {
pageNum = req.PageNum
}
if req.PageSize > 0 {
pageSize = req.PageSize
}
func (s *logsStatService) List(ctx context.Context, req *dto.ListModelStatReq) (*dto.ListModelStatRes, error) {
if req == nil {
req = &dto.ListModelStatReq{}
}
startDay, endDay := "", ""
var tenantID *int64
creator, modelName := "", ""
if req != nil {
startDay = req.StartDay
endDay = req.EndDay
tenantID = req.TenantID
creator = req.Creator
modelName = req.ModelName
if req.PageNum <= 0 {
req.PageNum = 1
}
list, total, err := dao.Stat.List(ctx, pageNum, pageSize, startDay, endDay, tenantID, creator, modelName)
if req.PageSize <= 0 {
req.PageSize = 10
}
list, total, err := dao.ModelGatewayLogsStat.List(ctx, req.PageNum, req.PageSize, &entity.ModelGatewayLogsStat{
Creator: req.Creator,
ModelName: req.ModelName,
})
if err != nil {
return nil, err
}

View File

@@ -17,25 +17,27 @@ import (
"gitea.redpowerfuture.com/red-future/common/utils"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/util/gconv"
"github.com/google/uuid"
)
var Task = &taskService{}
var ModelGatewayTask = &taskService{}
type taskService struct{}
// Create 创建任务
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
startAt := time.Now()
taskID := uuid.NewString()
var (
startAt = time.Now()
taskID = uuid.NewString()
)
// 1) 检查模型配置,并且获取模型
userInfo, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{
TenantId: userInfo.TenantId,
Creator: userInfo.UserName,
@@ -66,19 +68,17 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
// 异步调用:注入回调地址后提交,拿到 task_id 轮询
req.RequestPayload = util.InjectCallbackURL(ctx, req.RequestPayload, model.CallbackUrl)
}
storedPayload := map[string]any{
"headers": util.ParseHeadMsgHeaders(model.HeadMsg),
"body": req.RequestPayload,
requestPayload := entity.RequestPayload{
Body: req.RequestPayload,
Headers: util.ParseHeadMsgHeaders(model.HeadMsg),
}
_, err = dao.Task.Insert(ctx, &entity.AsynchTask{
id, err := dao.ModelGatewayTask.Insert(ctx, &entity.ModelGatewayTask{
ModelName: req.ModelName,
TaskID: taskID,
State: 0,
State: public.TaskStatusPending,
BizName: req.BizName,
CallbackURL: req.CallbackUrl,
ModelKey: model.ApiKey,
InputRef: req.InputRef,
RequestPayload: storedPayload,
RequestPayload: &requestPayload,
EpicycleId: req.EpicycleId,
})
if err != nil { // 入库失败:回滚闸门占位
@@ -97,7 +97,7 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
apiPath = r.URL.Path
httpMethod = r.Method
}
_, _ = dao.OpLog.Insert(ctx, &entity.LogsModelOp{
_, _ = dao.ModelGatewayLogsOp.Insert(ctx, &entity.ModelGatewayLogsOp{
IP: ip,
UserAgent: ua,
APIPath: apiPath,
@@ -109,20 +109,17 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
Success: 1,
ErrorMsg: "",
CostMs: time.Since(startAt).Milliseconds(),
RequestPayload: storedPayload,
RequestPayload: &requestPayload,
ResponsePayload: gdb.Map{
"taskId": taskID,
},
})
// 5) 获取任务信息
task, err := dao.Task.ClaimPendingByTaskIDGlobal(ctx, taskID)
task, err := dao.ModelGatewayTask.ClaimByID(ctx, id)
if err != nil {
return nil, err
}
if task == nil {
return nil, err
}
// 5) 创建成功后立即异步尝试执行当前任务
go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req)
@@ -130,10 +127,96 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
return &dto.CreateTaskRes{TaskID: taskID}, nil
}
// GetResult 获取任务结果
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
t, err := dao.ModelGatewayTask.Get(ctx, &entity.ModelGatewayTask{
TaskID: taskID,
})
if err != nil {
return nil, err
}
if t == nil {
return nil, errors.New("任务不存在")
}
return &dto.GetTaskResultRes{
OssFile: t.ResultFile.OssFile,
State: t.State,
}, nil
}
// GetBatch 批量查询任务;将成功(state=2)的任务更新为已下载(state=4),并写入过期时间
func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
if req == nil || len(req.TaskIDs) == 0 {
return &dto.GetTaskBatchRes{List: []dto.GetTaskBatchItem{}}, nil
}
// 1) 先查当前租户下的任务列表
list, err := dao.ModelGatewayTask.ListByTaskIDs(ctx, req.TaskIDs)
if err != nil {
return nil, err
}
// 2) 对成功(state=2)的任务:标记为已下载(state=4)
for _, t := range list {
if t == nil {
continue
}
if t.State != public.BuildTypeNode {
continue
}
_ = dao.ModelGatewayTask.MarkDownloadedByID(ctx, t.Id)
// 为了本次返回一致性,内存里也更新
t.State = public.TaskStatusDownloaded
}
// 3) 组装返回
items := make([]dto.GetTaskBatchItem, 0, len(list))
for _, t := range list {
if t == nil {
continue
}
items = append(items, dto.GetTaskBatchItem{
TaskID: t.TaskID,
State: t.State,
OssFile: t.ResultFile.OssFile,
TextResult: t.TextResult,
})
}
return &dto.GetTaskBatchRes{List: items}, nil
}
// List 获取任务列表
func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (*dto.ListTaskRes, error) {
if req.PageNum <= 0 {
req.PageNum = 1
}
if req.PageSize <= 0 {
req.PageSize = 10
}
user, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
list, total, err := dao.ModelGatewayTask.List(ctx, req.PageNum, req.PageSize, &entity.ModelGatewayTask{
SQLBaseDO: beans.SQLBaseDO{
Creator: user.UserName,
},
ModelName: req.ModelName,
BizName: req.BizName,
State: req.State,
TaskID: req.TaskID,
})
if err != nil {
return nil, err
}
return &dto.ListTaskRes{List: list, Total: total}, nil
}
// ModelTaskCallback 模型异步任务的回调通知
func (s *taskService) ModelTaskCallback(ctx context.Context, req *dto.ModelTaskCallbackReq) (*dto.ModelTaskCallbackRes, error) {
g.Log().Infof(ctx, "[模型回调] 收到通知 taskID=%s status=%s", req.TaskID, req.Status)
// 1. 查本地任务
task, err := dao.Task.Get(ctx, &entity.AsynchTask{
task, err := dao.ModelGatewayTask.Get(ctx, &entity.ModelGatewayTask{
TaskID: req.TaskID,
})
if err != nil || task == nil {
@@ -167,7 +250,7 @@ func (s *taskService) QueryPendingTasks(ctx context.Context, req *dto.QueryPendi
}
// 1. 查 state=1执行中的异步任务
tasks, err := dao.Task.GetPendingAsyncTasks(ctx, limit)
tasks, err := dao.ModelGatewayTask.GetPendingAsyncTasks(ctx, limit)
if err != nil {
return nil, err
}
@@ -176,7 +259,7 @@ func (s *taskService) QueryPendingTasks(ctx context.Context, req *dto.QueryPendi
var results []dto.QueryTaskItem
for _, t := range tasks {
// 拿到模型配置
model, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
model, err := dao.ModelGatewayModels.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
if err != nil || model == nil || model.QueryConfig == nil {
continue
}
@@ -206,100 +289,3 @@ func (s *taskService) QueryPendingTasks(ctx context.Context, req *dto.QueryPendi
Results: results,
}, nil
}
// GetResult 获取任务结果
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
TaskID: taskID,
})
if err != nil {
return nil, err
}
if t == nil {
return nil, errors.New("任务不存在")
}
return &dto.GetTaskResultRes{
OssFile: t.OssFile,
State: t.State,
}, nil
}
// GetBatch 批量查询任务;将成功(state=2)的任务更新为已下载(state=4),并写入过期时间
func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
if req == nil || len(req.TaskIDs) == 0 {
return &dto.GetTaskBatchRes{List: []dto.GetTaskBatchItem{}}, nil
}
// 1) 先查当前租户下的任务列表
list, err := dao.Task.ListByTaskIDs(ctx, req.TaskIDs)
if err != nil {
return nil, err
}
// 2) 对成功(state=2)的任务:标记为已下载(state=4)并写入 expire_at
now := time.Now()
for _, t := range list {
if t == nil {
continue
}
if t.State != 2 {
continue
}
// 按模型配置决定保留时间
m, err := dao.Model.Get(ctx, &entity.AsynchModel{
ModelName: t.ModelName,
})
if err != nil {
return nil, err
}
retainSeconds := 86400
if m != nil && m.AutoCleanSeconds > 0 {
retainSeconds = m.AutoCleanSeconds
}
expireAt := gtime.New(now.Add(time.Duration(retainSeconds) * time.Second))
_ = dao.Task.MarkDownloadedByID(ctx, t.Id, expireAt)
// 为了本次返回一致性,内存里也更新
t.State = 4
t.ExpireAt = expireAt
}
// 3) 组装返回
items := make([]dto.GetTaskBatchItem, 0, len(list))
for _, t := range list {
if t == nil {
continue
}
items = append(items, dto.GetTaskBatchItem{
TaskID: t.TaskID,
State: t.State,
OssFile: t.OssFile,
})
}
return &dto.GetTaskBatchRes{List: items}, nil
}
// List 获取任务列表
func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
pageNum, pageSize := 1, 10
if req != nil {
if req.PageNum > 0 {
pageNum = req.PageNum
}
if req.PageSize > 0 {
pageSize = req.PageSize
}
}
modelName := ""
taskID := ""
var state *int
if req != nil {
modelName = req.ModelName
taskID = req.TaskID
state = req.State
}
list, total, err := dao.Task.List(ctx, pageNum, pageSize, modelName, taskID, state)
if err != nil {
return nil, err
}
return &dto.ListTaskRes{List: list, Total: total}, nil
}

View File

@@ -24,7 +24,6 @@ import (
"gitea.redpowerfuture.com/red-future/common/beans"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/util/gconv"
)
@@ -34,11 +33,13 @@ type asyncWorker struct {
}
// handleOne 执行一次完整的任务
func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, req *dto.CreateTaskReq) {
body := util.GetModelBody(task.RequestPayload) // 核心请求参数
maxRetry := model.RetryTimes // 重试次
startTime := time.Now()
func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq) {
var (
body = task.RequestPayload.Body // 核心请求参
maxRetry = model.RetryTimes // 重试次数
startTime = time.Now()
modelMessages = map[string]any{}
)
g.Log().Infof(ctx, "[执行任务][开始] taskId=%s model=%s", task.TaskID, task.ModelName)
// 1) 分布式并发控制
@@ -51,8 +52,13 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo
return
}
if !acquired {
_, _ = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
SQLBaseDO: beans.SQLBaseDO{
Id: task.Id,
},
State: public.TaskStatusPending,
})
g.Log().Infof(ctx, "[执行任务][排队] 并发已满,放回队列 taskId=%s", task.TaskID)
_ = w.rollbackToPending(ctx, task.Id)
return
}
defer func() { _ = queue.ReleaseSemaphore(ctx, semKey) }()
@@ -65,24 +71,24 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo
w.failTask(ctx, task, startTime, err.Error())
return
}
body, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
modelMessages, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
if err != nil {
w.failTask(ctx, task, startTime, err.Error())
return
}
case model.CallMode != nil && *model.CallMode == public.CallModeAsync:
body, err = w.callModel(ctx, task, model, body)
modelMessages, err = w.callModel(ctx, task, model, body)
if err != nil {
w.failTask(ctx, task, startTime, err.Error())
return
}
body, err = util.PullTaskResult(ctx, body, model.QueryConfig, model.HeadMsg)
modelMessages, err = util.PullTaskResult(ctx, modelMessages, model.QueryConfig, model.HeadMsg)
if err != nil {
w.failTask(ctx, task, startTime, err.Error())
return
}
default:
body, err = w.callModel(ctx, task, model, body)
modelMessages, err = w.callModel(ctx, task, model, body)
if err != nil {
w.failTask(ctx, task, startTime, err.Error())
return
@@ -90,20 +96,20 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo
}
// 3) 保存临时文件
tmpPath, err := util.SaveTempFileByType(task.TaskID, body, task.TmpFile)
tmpPath, err := util.SaveTempFileByType(task.TaskID, modelMessages, task.TmpFile)
if err == nil && tmpPath != "" {
task.TmpFile = tmpPath
task.Phase = 1
_, err = dao.Task.Update(ctx, task)
_, err = dao.ModelGatewayTask.Update(ctx, task)
if err != nil {
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err)
}
}
// 4) 解析校验 + 响应映射(可重试,失败重新调模型)
body, err = w.parseAndRetry(ctx, body, task, model, req, maxRetry, startTime)
modelMessages, err = w.parseAndRetry(ctx, modelMessages, task, model, req, maxRetry, startTime)
if err != nil {
task.TextResult = body
task.TextResult = modelMessages
w.failTask(ctx, task, startTime, err.Error())
return
}
@@ -123,9 +129,8 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo
if attempt == maxRetry {
task.State = 3
task.ErrorMsg = err.Error()
task.FinishedAt = gtime.Now()
task.Phase = 1
_, err = dao.Task.Update(ctx, task)
_, err = dao.ModelGatewayTask.Update(ctx, task)
if err != nil {
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err)
}
@@ -137,12 +142,13 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo
// 6) 成功回调
task.State = 2
task.DurationSeconds = int64(time.Since(startTime).Seconds())
task.OssFile = oss.FileAddressPrefix + oss.FileURL
task.FileType = oss.FileFormat
task.TextResult = body
task.FileSize = int64(oss.FileSize)
if _, err = dao.Task.Update(ctx, task); err != nil {
task.ResultFile = &entity.ResultFile{
OssFile: oss.FileAddressPrefix + oss.FileURL,
FileType: oss.FileFormat,
FileSize: int64(oss.FileSize),
}
task.TextResult = modelMessages
if _, err = dao.ModelGatewayTask.Update(ctx, task); err != nil {
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err)
return
}
@@ -161,7 +167,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo
}
// callModelStream 调用模型,返回原始字节(不做响应映射,用于流式输出)
func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) ([]byte, error) {
func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) {
var data []byte
var err error
@@ -173,8 +179,7 @@ func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.AsynchTa
}
if data == nil {
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName)
data, err = InvokeModel(ctx, model, body, task.ModelKey)
data, err = InvokeModel(ctx, model, body)
if err != nil {
return nil, err
}
@@ -182,7 +187,7 @@ func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.AsynchTa
if tmpErr == nil && tmpPath != "" {
task.TmpFile = tmpPath
task.Phase = 1
_, err = dao.Task.Update(ctx, task)
_, err = dao.ModelGatewayTask.Update(ctx, task)
if err != nil {
g.Log().Errorf(ctx, "[执行任务][失败] 临时文件保存失败 taskId=%s err=%v", task.TaskID, tmpErr)
}
@@ -201,7 +206,7 @@ type asyncResult struct {
// asyncTaskChan 全局异步任务等待通道
var asyncTaskChan = sync.Map{} // taskID → chan asyncResult
func (w *asyncWorker) callModelAsync(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) (map[string]any, error) {
func (w *asyncWorker) callModelAsync(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, body map[string]any) (map[string]any, error) {
// 1. 提交异步任务
body, err := w.callModel(ctx, task, model, body)
if err != nil {
@@ -246,7 +251,7 @@ func NotifyAsyncResult(taskID string, result map[string]any, err error) {
// callModel 调用模型 + 检测文件类型 + 保存临时文件
// 返回: 解析后的响应体, error
func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) (map[string]any, error) {
func (w *asyncWorker) callModel(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, body map[string]any) (map[string]any, error) {
var data []byte
var err error
@@ -261,8 +266,7 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, mo
// 2) 没有可用数据,调用模型
if data == nil {
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName)
data, err = InvokeModel(ctx, model, body, task.ModelKey)
data, err = InvokeModel(ctx, model, body)
if err != nil {
return nil, err
}
@@ -273,7 +277,7 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, mo
if tmpErr == nil && tmpPath != "" {
task.TmpFile = tmpPath
task.Phase = 1
_, err = dao.Task.Update(ctx, task)
_, err = dao.ModelGatewayTask.Update(ctx, task)
if err != nil {
g.Log().Errorf(ctx, "[执行任务][失败] 临时文件保存失败 taskId=%s err=%v", task.TaskID, tmpErr)
}
@@ -297,7 +301,7 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, mo
}
// parseAndRetry 解析模型返回结果,并重试
func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.AsynchTask, model *entity.AsynchModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) {
func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) {
for attempt := 0; attempt <= maxRetry; attempt++ {
if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
@@ -316,7 +320,7 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
// 2) 先存 token 到数据库,防止后续失败丢失
if _, ok := mapped[model.ResponseTokenField]; ok {
task.ExpendTokens = gconv.Int64(mapped[model.ResponseTokenField])
_, err = dao.Task.Update(ctx, &entity.AsynchTask{
_, err = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
SQLBaseDO: beans.SQLBaseDO{Id: task.Id},
ExpendTokens: task.ExpendTokens,
})
@@ -344,9 +348,10 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
}
// 4) 重新调模型(直接调,不走缓存)
_ = dao.Task.IncRetryCountGlobal(ctx, task.Id)
reqBody := util.GetModelBody(task.RequestPayload)
rawData, callErr := InvokeModel(ctx, model, reqBody, task.ModelKey)
task.RetryCount++
_, _ = dao.ModelGatewayTask.Update(ctx, task)
rawData, callErr := InvokeModel(ctx, model, task.RequestPayload.Body)
if callErr != nil {
g.Log().Warningf(ctx, "[执行任务][重调模型失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, callErr)
continue
@@ -354,7 +359,7 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
// 5) 解析原始响应,覆盖 body 进入下一轮
var rawResp map[string]any
if err := json.Unmarshal(rawData, &rawResp); err != nil {
if err = json.Unmarshal(rawData, &rawResp); err != nil {
g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err)
continue
}
@@ -366,18 +371,21 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
// InvokeModel 调用模型服务,返回二进制结果
// modelKey 用于覆盖/补充模型配置 head_msg例如每次请求携带不同的 X-API-Key
func InvokeModel(ctx context.Context, model *entity.AsynchModel, body map[string]any, modelKey string) ([]byte, error) {
// 1)请求参数映射:将标准 payload 按模型配置的 requestMapping 转为模型需要的格式
func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) {
// 1) 记录模型调用次数
_ = dao.ModelGatewayLogsStat.IncRequestCount(ctx, time.Now(), model.TenantId, model.Creator, model.ModelName)
// 2请求参数映射将标准 payload 按模型配置的 requestMapping 转为模型需要的格式
//—— 请求映射实际处理为提示词构建请求,因为有附加字段及其他字段的拼接。这里不方便做请求映射
//mappedPayload := util.ReverseMap(model.RequestMapping, payload)
// 2)构建请求 URL 和超时
// 3)构建请求 URL 和超时
baseURL := strings.TrimRight(model.BaseURL, "/")
timeout := time.Duration(model.TimeoutSeconds) * time.Second
client := &http.Client{Timeout: timeout}
method := strings.ToUpper(strings.TrimSpace(model.HttpMethod))
// 3)构建 HTTP 请求
// 4)构建 HTTP 请求
var req *http.Request
switch method {
case http.MethodGet:
@@ -401,31 +409,31 @@ func InvokeModel(ctx context.Context, model *entity.AsynchModel, body map[string
req, err = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes))
}
// 4)注入请求头:先模型静态配置,再动态 modelKey后者可覆盖前者
// 5)注入请求头:先模型静态配置,再动态 modelKey后者可覆盖前者
for hk, hv := range util.ParseHeadMsgHeaders(model.HeadMsg) {
req.Header.Set(hk, hv)
}
if modelKey != "" {
req.Header.Set("Authorization", "Bearer "+modelKey)
if model.ApiKey != "" {
req.Header.Set("Authorization", "Bearer "+model.ApiKey)
}
if method != http.MethodGet {
req.Header.Set("Content-Type", "application/json")
}
// 5)发送请求
// 6)发送请求
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
// 6)读取响应体
// 7)读取响应体
b, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
// 7)检查 HTTP 状态码
// 8)检查 HTTP 状态码
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
msg := string(b)
return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
@@ -488,7 +496,7 @@ func InvokeModel(ctx context.Context, model *entity.AsynchModel, body map[string
// }
// uploadOSS 从临时文件上传 OSS
func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.AsynchTask) (*gateway.UploadFileResponse, error) {
func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.ModelGatewayTask) (*gateway.UploadFileResponse, error) {
data, err := os.ReadFile(t.TmpFile)
if err != nil {
return nil, fmt.Errorf("读取临时文件失败: %w", err)
@@ -498,19 +506,14 @@ func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.AsynchTask) (*gat
}
// failTask 任务失败统一处理:更新数据库 + 释放排队 + 回调
func (w *asyncWorker) failTask(ctx context.Context, t *entity.AsynchTask, startTime time.Time, errMsg string) {
func (w *asyncWorker) failTask(ctx context.Context, t *entity.ModelGatewayTask, startTime time.Time, errMsg string) {
t.State = 3
t.ErrorMsg = errMsg
t.DurationSeconds = int64(time.Since(startTime).Seconds())
_, err := dao.Task.Update(ctx, t)
_, err := dao.ModelGatewayTask.Update(ctx, t)
if err != nil {
g.Log().Warningf(ctx, "[执行任务][更新数据库失败] taskId=%s err=%v", t.TaskID, err)
}
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
}
// rollbackToPending 恢复任务状态为 PENDING
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
return dao.Task.RollbackToPendingGlobal(ctx, id)
}