refactor(model): 重构模型实体和数据访问层
This commit is contained in:
@@ -3,7 +3,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"model-gateway/common/util"
|
||||
"time"
|
||||
|
||||
"model-gateway/dao"
|
||||
@@ -21,13 +21,13 @@ var Task = &taskService{}
|
||||
type taskService struct{}
|
||||
|
||||
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
|
||||
fmt.Printf("打印请求:%+v", req)
|
||||
startAt := time.Now()
|
||||
// 固化 token/user 等信息
|
||||
ctx = asyncCtx(ctx)
|
||||
|
||||
ctx = util.AsyncCtx(ctx)
|
||||
// 1) 检查模型配置
|
||||
m, err := dao.Model.GetByModelName(ctx, req.ModelName)
|
||||
m, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
ModelName: req.ModelName,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -51,7 +51,7 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
|
||||
// 将调用模型的 payload 与透传头信息一起存入 request_payload,供后台 worker 使用
|
||||
storedPayload := map[string]any{
|
||||
"payload": req.RequestPayload,
|
||||
"headers": forwardHeaders(ctx),
|
||||
"headers": util.ForwardHeaders(ctx),
|
||||
}
|
||||
|
||||
t := &entity.AsynchTask{
|
||||
@@ -127,7 +127,9 @@ func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string,
|
||||
defer ticker.Stop()
|
||||
|
||||
tryRun := func() bool {
|
||||
t, err := dao.Task.GetByTaskID(ctx, taskID)
|
||||
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
|
||||
TaskID: taskID,
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[task-auto-run][stop] taskId=%s reason=query_failed err=%v", taskID, err)
|
||||
return true
|
||||
@@ -138,7 +140,7 @@ func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string,
|
||||
}
|
||||
switch t.State {
|
||||
case 0:
|
||||
if err := AsyncWorker.RunByTaskID(ctx, taskID, epicycleId); err != nil {
|
||||
if err = AsyncWorker.RunByTaskID(ctx, taskID, epicycleId); err != nil {
|
||||
g.Log().Warningf(ctx, "[task-auto-run][retry] taskId=%s state=0 err=%v", taskID, err)
|
||||
} else {
|
||||
g.Log().Infof(ctx, "[task-auto-run][triggered] taskId=%s state=0", taskID)
|
||||
@@ -175,7 +177,9 @@ func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string,
|
||||
}
|
||||
|
||||
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
|
||||
t, err := dao.Task.GetByTaskID(ctx, taskID)
|
||||
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
|
||||
TaskID: taskID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -209,7 +213,9 @@ func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (r
|
||||
continue
|
||||
}
|
||||
// 按模型配置决定保留时间
|
||||
m, err := dao.Model.GetByModelName(ctx, t.ModelName)
|
||||
m, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
ModelName: t.ModelName,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user