refactor(model): 重构模型实体和数据访问层

This commit is contained in:
2026-05-21 10:41:37 +08:00
parent a080a5536d
commit 170568e03e
35 changed files with 903 additions and 1072 deletions

View File

@@ -3,7 +3,7 @@ package service
import (
"context"
"errors"
"fmt"
"model-gateway/common/util"
"time"
"model-gateway/dao"
@@ -21,13 +21,13 @@ var Task = &taskService{}
type taskService struct{}
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
fmt.Printf("打印请求:%+v", req)
startAt := time.Now()
// 固化 token/user 等信息
ctx = asyncCtx(ctx)
ctx = util.AsyncCtx(ctx)
// 1) 检查模型配置
m, err := dao.Model.GetByModelName(ctx, req.ModelName)
m, err := dao.Model.Get(ctx, &entity.AsynchModel{
ModelName: req.ModelName,
})
if err != nil {
return nil, err
}
@@ -51,7 +51,7 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
// 将调用模型的 payload 与透传头信息一起存入 request_payload供后台 worker 使用
storedPayload := map[string]any{
"payload": req.RequestPayload,
"headers": forwardHeaders(ctx),
"headers": util.ForwardHeaders(ctx),
}
t := &entity.AsynchTask{
@@ -127,7 +127,9 @@ func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string,
defer ticker.Stop()
tryRun := func() bool {
t, err := dao.Task.GetByTaskID(ctx, taskID)
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
TaskID: taskID,
})
if err != nil {
g.Log().Warningf(ctx, "[task-auto-run][stop] taskId=%s reason=query_failed err=%v", taskID, err)
return true
@@ -138,7 +140,7 @@ func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string,
}
switch t.State {
case 0:
if err := AsyncWorker.RunByTaskID(ctx, taskID, epicycleId); err != nil {
if err = AsyncWorker.RunByTaskID(ctx, taskID, epicycleId); err != nil {
g.Log().Warningf(ctx, "[task-auto-run][retry] taskId=%s state=0 err=%v", taskID, err)
} else {
g.Log().Infof(ctx, "[task-auto-run][triggered] taskId=%s state=0", taskID)
@@ -175,7 +177,9 @@ func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string,
}
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
t, err := dao.Task.GetByTaskID(ctx, taskID)
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
TaskID: taskID,
})
if err != nil {
return nil, err
}
@@ -209,7 +213,9 @@ func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (r
continue
}
// 按模型配置决定保留时间
m, err := dao.Model.GetByModelName(ctx, t.ModelName)
m, err := dao.Model.Get(ctx, &entity.AsynchModel{
ModelName: t.ModelName,
})
if err != nil {
return nil, err
}