diff --git a/common/util/mapping.go b/common/util/mapping.go index 58d82ea..b071b2e 100644 --- a/common/util/mapping.go +++ b/common/util/mapping.go @@ -21,7 +21,7 @@ import ( ) // ParseAndValidate 解析并校验结果 -func ParseAndValidate(raw map[string]any, model *entity.AsynchModel) (map[string]any, error) { +func ParseAndValidate(raw map[string]any, model *entity.ModelGatewayModel) (map[string]any, error) { // 1) 解析 content 字符串为 rounds 数组 contentVal, ok := raw[model.ResponseBody] if !ok { @@ -94,53 +94,6 @@ func ParseStructResult(raw map[string]any, responseBody string) map[string]any { } } -// ValidatePromptResult 校验模型返回结果的 JSON 结构完整性 -// raw 必须包含 "rounds" 字段,格式为 []map[string]any -func ValidatePromptResult(raw map[string]any, model *entity.AsynchModel) error { - // 1) 获取 rounds - roundsRaw, ok := raw["rounds"] - if !ok { - return fmt.Errorf("缺少 rounds 字段") - } - rounds, ok := roundsRaw.([]any) - if !ok { - return fmt.Errorf("rounds 不是数组") - } - if len(rounds) == 0 { - return fmt.Errorf("rounds 数组为空") - } - - // 2) 没有配置必填字段,跳过 - if len(model.RequiredFields) == 0 { - return nil - } - - // 3) 逐条校验 - for i, r := range rounds { - round, ok := r.(map[string]any) - if !ok { - continue - } - for _, field := range model.RequiredFields { - if gjson.New(round).Get(field).IsNil() { - return fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, field) - } - } - } - - return nil -} - -// validateRequiredFields 校验单个 round 对象的必选字段 -func validateRequiredFields(round map[string]any, requiredFields []string, prefix string) error { - for _, field := range requiredFields { - if gjson.New(round).Get(field).IsNil() { - return fmt.Errorf("%s 缺少必填字段: %s", prefix, field) - } - } - return nil -} - // ParseHeadMsgHeaders 从 head_msg JSON 中提取请求头 // head_msg 格式示例: // @@ -198,16 +151,17 @@ func MapResponsePayload(mapping map[string]any, result map[string]any) (map[stri return mapped, nil } -// GetModelBody 获取数据库中保存的模型信息 -func GetModelBody(v map[string]any) map[string]any { - if v == nil { - return nil - } - if p, ok := v["body"]; ok { - return gconv.Map(p) - } - return v -} +// +//// GetModelBody 获取数据库中保存的模型信息 +//func GetModelBody(v map[string]any) map[string]any { +// if v == nil { +// return nil +// } +// if p, ok := v["body"]; ok { +// return gconv.Map(p) +// } +// return v +//} // BodyToQuery 将 body 转为 url.Values func BodyToQuery(payload map[string]any) (url.Values, error) { diff --git a/consts/public/public.go b/consts/public/public.go index a5c794c..f20e1dc 100644 --- a/consts/public/public.go +++ b/consts/public/public.go @@ -6,6 +6,14 @@ const ( CallModeStream = 2 // 流式调用 ) +const ( + TaskStatusPending = 0 // 排队中 + TaskStatusRunning = 1 // 执行中 + TaskStatusSuccess = 2 // 成功 + TaskStatusFailed = 3 // 失败 + TaskStatusDownloaded = 4 // 已下载 +) + const ( BuildTypePrompt = 1 //提示词构建 BuildTypeNode = 2 //节点构建 diff --git a/consts/public/table_name.go b/consts/public/table_name.go index 16cf1c0..a7f89cb 100644 --- a/consts/public/table_name.go +++ b/consts/public/table_name.go @@ -5,8 +5,8 @@ const ( ) const ( - TableNameModel = "asynch_models" // 模型表 - TableNameTask = "asynch_task" // 任务表 - TableNameOpLog = "logs_model_op" // 操作日志表 - TableNameStat = "logs_model_stat" // 按天统计表(请求次数) + TableNameModel = "model_gateway_models" // 模型表 + TableNameTask = "model_gateway_task" // 任务表 + TableNameOpLog = "model_gateway_logs_op" // 操作日志表 + TableNameStat = "model_gateway_logs_stat" // 按天统计表 ) diff --git a/controller/stat_controller.go b/controller/model_gateway_logs_stat_controller.go similarity index 68% rename from controller/stat_controller.go rename to controller/model_gateway_logs_stat_controller.go index 64325e4..636c026 100644 --- a/controller/stat_controller.go +++ b/controller/model_gateway_logs_stat_controller.go @@ -7,12 +7,12 @@ import ( "model-gateway/model/dto" ) -type stat struct{} +// ModelGatewayLogsStat 统计控制器 +var ModelGatewayLogsStat = new(stat) -// Stat 统计控制器 -var Stat = new(stat) +type stat struct{} // ListModelStat 统计列表 func (c *stat) ListModelStat(ctx context.Context, req *dto.ListModelStatReq) (res *dto.ListModelStatRes, err error) { - return statService.Stat.List(ctx, req) + return statService.ModelGatewayLogsStat.List(ctx, req) } diff --git a/controller/model_controller.go b/controller/model_gateway_models_controller.go similarity index 79% rename from controller/model_controller.go rename to controller/model_gateway_models_controller.go index e81a39f..d7a0810 100644 --- a/controller/model_controller.go +++ b/controller/model_gateway_models_controller.go @@ -7,36 +7,36 @@ import ( "model-gateway/service/queue" ) -type model struct{} +// ModelGatewayModels 模型配置控制器 +var ModelGatewayModels = new(model) -// Model 模型配置控制器 -var Model = new(model) +type model struct{} // CreateModel 添加配置 func (c *model) CreateModel(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) { - return modelService.Model.Create(ctx, req) + return modelService.ModelGatewayModels.Create(ctx, req) } // UpdateModel 更改配置 func (c *model) UpdateModel(ctx context.Context, req *dto.UpdateModelReq) (res *dto.UpdateModelRes, err error) { - err = modelService.Model.Update(ctx, req) + err = modelService.ModelGatewayModels.Update(ctx, req) return } // DeleteModel 删除配置 func (c *model) DeleteModel(ctx context.Context, req *dto.DeleteModelReq) (res *dto.DeleteModelRes, err error) { - err = modelService.Model.Delete(ctx, req) + err = modelService.ModelGatewayModels.Delete(ctx, req) return } // GetModel 获取配置详情 func (c *model) GetModel(ctx context.Context, req *dto.GetModelReq) (res *dto.GetModelRes, err error) { - return modelService.Model.Get(ctx, req) + return modelService.ModelGatewayModels.Get(ctx, req) } // ListModel 配置列表 func (c *model) ListModel(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) { - return modelService.Model.List(ctx, req) + return modelService.ModelGatewayModels.List(ctx, req) } // AutoTune 动态调参(由上层定时任务每小时触发一次) @@ -56,11 +56,11 @@ func (c *model) ListOperator(ctx context.Context, req *dto.ListOperatorReq) (res // UpdateChatModel 更新是否为聊天模型 func (c *model) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) (res *dto.UpdateChatModelRes, err error) { - err = modelService.Model.UpdateChatModel(ctx, req) + err = modelService.ModelGatewayModels.UpdateChatModel(ctx, req) return } // GetIsChatModel 获取当前会话模型 func (c *model) GetIsChatModel(ctx context.Context, req *dto.GetIsChatModelReq) (res *dto.GetIsChatModelRes, err error) { - return modelService.Model.GetIsChatModel(ctx) + return modelService.ModelGatewayModels.GetIsChatModel(ctx) } diff --git a/controller/task_controller.go b/controller/model_gateway_task_controller.go similarity index 53% rename from controller/task_controller.go rename to controller/model_gateway_task_controller.go index 9557c76..1ca0b0c 100644 --- a/controller/task_controller.go +++ b/controller/model_gateway_task_controller.go @@ -2,48 +2,42 @@ package controller import ( "context" - "model-gateway/service/job" taskService "model-gateway/service/task" "model-gateway/model/dto" ) -type task struct{} +// ModelGatewayTask 任务控制器 +var ModelGatewayTask = new(task) -// Task 任务控制器 -var Task = new(task) +type task struct{} // CreateTask 根据 modelName 创建异步任务,返回 taskId func (c *task) CreateTask(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) { - return taskService.Task.Create(ctx, req) + return taskService.ModelGatewayTask.Create(ctx, req) } -// ModelTaskCallback 接收模型异步任务的回调通知 -func (c *task) ModelTaskCallback(ctx context.Context, req *dto.ModelTaskCallbackReq) (res *dto.ModelTaskCallbackRes, err error) { - return taskService.Task.ModelTaskCallback(ctx, req) -} - -// QueryPendingTasks 批量轮询进行中的异步任务 -func (c *task) QueryPendingTasks(ctx context.Context, req *dto.QueryPendingTasksReq) (res *dto.QueryPendingTasksRes, err error) { - return taskService.Task.QueryPendingTasks(ctx, req) -} - -// GetTaskResult 获取任务结果(只返回 oss 地址 + state) +// GetTaskResult 获取单条任务结果(返回 *dto.GetTaskResultRes) func (c *task) GetTaskResult(ctx context.Context, req *dto.GetTaskResultReq) (res *dto.GetTaskResultRes, err error) { - return taskService.Task.GetResult(ctx, req.TaskID) + return taskService.ModelGatewayTask.GetResult(ctx, req.TaskID) } -// GetTaskBatch 批量查询任务(成功任务标记为已下载) +// GetTaskBatch 批量查询任务(返回 *[]dto.GetTaskBatchItem) func (c *task) GetTaskBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) { - return taskService.Task.GetBatch(ctx, req) + return taskService.ModelGatewayTask.GetBatch(ctx, req) } // ListTask 任务列表分页查询 func (c *task) ListTask(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) { - return taskService.Task.List(ctx, req) + return taskService.ModelGatewayTask.List(ctx, req) } -// CleanWork 手动触发一次 cleaner(由上层定时任务调用) -func (c *task) CleanWork(ctx context.Context, req *dto.CleanWorkReq) (res *dto.CleanWorkRes, err error) { - return job.Cleaner.RunOnce(ctx) +// ModelTaskCallback 接收模型异步任务的回调通知 —— 待调整 +func (c *task) ModelTaskCallback(ctx context.Context, req *dto.ModelTaskCallbackReq) (res *dto.ModelTaskCallbackRes, err error) { + return taskService.ModelGatewayTask.ModelTaskCallback(ctx, req) +} + +// QueryPendingTasks 批量轮询进行中的异步任务 —— 待调整 +func (c *task) QueryPendingTasks(ctx context.Context, req *dto.QueryPendingTasksReq) (res *dto.QueryPendingTasksRes, err error) { + return taskService.ModelGatewayTask.QueryPendingTasks(ctx, req) } diff --git a/dao/model_gateway_logs_op.go b/dao/model_gateway_logs_op.go new file mode 100644 index 0000000..98d05e4 --- /dev/null +++ b/dao/model_gateway_logs_op.go @@ -0,0 +1,23 @@ +package dao + +import ( + "context" + + "model-gateway/consts/public" + "model-gateway/model/entity" + + "gitea.redpowerfuture.com/red-future/common/db/gfdb" +) + +var ModelGatewayLogsOp = &modelGatewayLogsOpDao{} + +type modelGatewayLogsOpDao struct{} + +// Insert 插入操作日志 +func (d *modelGatewayLogsOpDao) Insert(ctx context.Context, req *entity.ModelGatewayLogsOp) (int64, error) { + r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameOpLog).Insert(req) + if err != nil { + return 0, err + } + return r.LastInsertId() +} diff --git a/dao/model_gateway_logs_stat_dao.go b/dao/model_gateway_logs_stat_dao.go new file mode 100644 index 0000000..54ae746 --- /dev/null +++ b/dao/model_gateway_logs_stat_dao.go @@ -0,0 +1,52 @@ +package dao + +import ( + "context" + "time" + + "model-gateway/consts/public" + "model-gateway/model/entity" + + "gitea.redpowerfuture.com/red-future/common/db/gfdb" + "github.com/gogf/gf/v2/os/gtime" + "github.com/gogf/gf/v2/util/gconv" +) + +var ModelGatewayLogsStat = &modelGatewayLogsStatDao{} + +type modelGatewayLogsStatDao struct{} + +// IncRequestCount 原子累加:按天+租户+创建人+模型 +1 +func (d *modelGatewayLogsStatDao) IncRequestCount(ctx context.Context, day time.Time, tenantId uint64, creator, modelName string) error { + _, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameStat). + Data(&entity.ModelGatewayLogsStat{ + Day: gtime.New(day), + TenantId: tenantId, + Creator: creator, + ModelName: modelName, + RequestCount: 1, + }). + OnDuplicate("request_count", "request_count+1"). + Insert() + return err +} + +// List 分页查询统计 +func (d *modelGatewayLogsStatDao) List(ctx context.Context, pageNum, pageSize int, req *entity.ModelGatewayLogsStat) (list []*entity.ModelGatewayLogsStat, total int64, err error) { + model := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameStat). + OmitEmpty(). + Where(entity.ModelGatewayLogsStatCols.Creator, req.Creator). + WhereLike(entity.ModelGatewayLogsStatCols.ModelName, "%"+req.ModelName+"%"). + OrderDesc(entity.ModelGatewayLogsStatCols.Day). + OrderDesc(entity.ModelGatewayLogsStatCols.RequestCount) + if pageNum > 0 && pageSize > 0 { + model = model.Page(pageNum, pageSize) + } + r, totalInt, err := model.AllAndCount(false) + if err != nil { + return nil, 0, err + } + total = gconv.Int64(totalInt) + err = r.Structs(&list) + return +} diff --git a/dao/model_dao.go b/dao/model_gateway_models_dao.go similarity index 70% rename from dao/model_dao.go rename to dao/model_gateway_models_dao.go index 151617a..be0c423 100644 --- a/dao/model_dao.go +++ b/dao/model_gateway_models_dao.go @@ -9,71 +9,64 @@ import ( "gitea.redpowerfuture.com/red-future/common/db/gfdb" "github.com/gogf/gf/v2/frame/g" - "github.com/gogf/gf/v2/util/gconv" ) -var Model = &modelDao{} +var ModelGatewayModels = &modelGatewayModelsDao{} -type modelDao struct{} +type modelGatewayModelsDao struct{} // Insert 插入 -func (d *modelDao) Insert(ctx context.Context, req *entity.AsynchModel) (id int64, err error) { - m := new(entity.AsynchModel) - err = gconv.Struct(req, &m) +func (d *modelGatewayModelsDao) Insert(ctx context.Context, req *entity.ModelGatewayModel) (int64, error) { + r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).Insert(req) if err != nil { - return - } - r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel). - Insert(m) - if err != nil { - return + return 0, err } return r.LastInsertId() } // Update 更新 -func (d *modelDao) Update(ctx context.Context, req *entity.AsynchModel) (rows int64, err error) { +func (d *modelGatewayModelsDao) Update(ctx context.Context, req *entity.ModelGatewayModel) (int64, error) { r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel). OmitEmpty(). - Data(&req). - Where(entity.AsynchModelCol.Id, req.Id). + Data(req). + Where(entity.ModelGatewayModelCol.Id, req.Id). Update() if err != nil { - return + return 0, err } return r.RowsAffected() } // Delete 删除 -func (d *modelDao) Delete(ctx context.Context, req *entity.AsynchModel) (rows int64, err error) { +func (d *modelGatewayModelsDao) Delete(ctx context.Context, req *entity.ModelGatewayModel) (int64, error) { r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel). OmitEmpty(). - Where(entity.AsynchModelCol.Id, req.Id). + Where(entity.ModelGatewayModelCol.Id, req.Id). Delete() if err != nil { - return + return 0, err } return r.RowsAffected() } // Get 获取模型 -func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) { +func (d *modelGatewayModelsDao) Get(ctx context.Context, req *entity.ModelGatewayModel, fields ...string) (*entity.ModelGatewayModel, error) { r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel). OmitEmpty(). - Where(entity.AsynchModelCol.Id, req.Id). - Where(entity.AsynchModelCol.Creator, req.Creator). - Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel). - Where(entity.AsynchModelCol.ModelName, req.ModelName). + Where(entity.ModelGatewayModelCol.Id, req.Id). + Where(entity.ModelGatewayModelCol.Creator, req.Creator). + Where(entity.ModelGatewayModelCol.ModelName, req.ModelName). Fields(fields).One() if err != nil { - return + return nil, err } + var m entity.ModelGatewayModel err = r.Struct(&m) - return + return &m, err } //// Get 按ID获取(带租户隔离,只查当前租户) -//func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) { +//func (d *modelGatewayModelsDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) { // var whereCondition strings.Builder // var queryParams []interface{} // if !g.IsEmpty(req.Id) { @@ -108,25 +101,25 @@ func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...s // return //} -// GetByAcrossTenant 按ID获取(跨租户,查所有租户) -func (d *modelDao) GetByAcrossTenant(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) { +// GetByAcrossTenant 跨租户查询 +func (d *modelGatewayModelsDao) GetByAcrossTenant(ctx context.Context, req *entity.ModelGatewayModel, fields ...string) (*entity.ModelGatewayModel, error) { r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel). NoTenantId(ctx). OmitEmpty(). - Where(entity.AsynchModelCol.Id, req.Id). - Where(entity.AsynchModelCol.Creator, req.Creator). - Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel). - Where(entity.AsynchModelCol.ModelName, req.ModelName). + Where(entity.ModelGatewayModelCol.Id, req.Id). + Where(entity.ModelGatewayModelCol.Creator, req.Creator). + Where(entity.ModelGatewayModelCol.ModelName, req.ModelName). Fields(fields).One() if err != nil { - return + return nil, err } + var m entity.ModelGatewayModel err = r.Struct(&m) - return + return &m, err } // GetByCreatorAndPlatform 按创建者、平台获取 -func (d *modelDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.AsynchModel, total int, err error) { +func (d *modelGatewayModelsDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.ModelGatewayModel, total int, err error) { sql := ` SELECT DISTINCT ON (model_name) * FROM asynch_models @@ -186,7 +179,7 @@ WHERE deleted_at IS NULL } // GetByModelNameForTenant 后台任务使用:按 tenant_id + model_name 查询,不依赖 gfdb Hook/Trace/用户上下文 -func (d *modelDao) GetByModelNameForTenant(ctx context.Context, tenantId uint64, modelName string) (m *entity.AsynchModel, err error) { +func (d *modelGatewayModelsDao) GetByModelNameForTenant(ctx context.Context, tenantId uint64, modelName string) (*entity.ModelGatewayModel, error) { r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, "SELECT * FROM "+public.TableNameModel+" WHERE tenant_id=? AND model_name=? AND deleted_at IS NULL LIMIT 1", tenantId, modelName, @@ -197,7 +190,7 @@ func (d *modelDao) GetByModelNameForTenant(ctx context.Context, tenantId uint64, if r.IsEmpty() { return nil, nil } - var list []*entity.AsynchModel + var list []*entity.ModelGatewayModel if err := r.Structs(&list); err != nil { return nil, err } diff --git a/dao/model_gateway_task_dao.go b/dao/model_gateway_task_dao.go new file mode 100644 index 0000000..7d9c209 --- /dev/null +++ b/dao/model_gateway_task_dao.go @@ -0,0 +1,159 @@ +package dao + +import ( + "context" + "fmt" + "model-gateway/consts/public" + "model-gateway/model/entity" + + "gitea.redpowerfuture.com/red-future/common/db/gfdb" + "github.com/gogf/gf/v2/database/gdb" + "github.com/gogf/gf/v2/util/gconv" +) + +var ModelGatewayTask = &modelGatewayTaskDao{} + +type modelGatewayTaskDao struct{} + +// Insert 插入 +func (d *modelGatewayTaskDao) Insert(ctx context.Context, req *entity.ModelGatewayTask) (id int64, err error) { + m := new(entity.ModelGatewayTask) + err = gconv.Struct(req, &m) + if err != nil { + return + } + r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).Insert(m) + if err != nil { + return + } + return r.LastInsertId() +} + +// Update 更新(按ID) +func (d *modelGatewayTaskDao) Update(ctx context.Context, req *entity.ModelGatewayTask) (rows int64, err error) { + r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask). + OmitEmpty(). + Data(req). + Where(entity.ModelGatewayTaskCol.Id, req.Id). + Update() + if err != nil { + return + } + return r.RowsAffected() +} + +// Get 获取(按TaskID 或 ID) +func (d *modelGatewayTaskDao) Get(ctx context.Context, req *entity.ModelGatewayTask) (m *entity.ModelGatewayTask, err error) { + r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask). + OmitEmpty(). + Where(entity.ModelGatewayTaskCol.TaskID, req.TaskID). + Where(entity.ModelGatewayTaskCol.Id, req.Id). + One() + if err != nil { + return + } + err = r.Struct(&m) + return +} + +// List 分页查询 +func (d *modelGatewayTaskDao) List(ctx context.Context, pageNum, pageSize int, req *entity.ModelGatewayTask) (list []*entity.ModelGatewayTask, total int64, err error) { + model := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask). + OmitEmpty(). + Where(entity.ModelGatewayTaskCol.Creator, req.Creator). + Where(entity.ModelGatewayTaskCol.ModelName, "%"+req.ModelName+"%"). + Where(entity.ModelGatewayTaskCol.BizName, req.BizName). + Where(entity.ModelGatewayTaskCol.State, req.State). + Where(entity.ModelGatewayTaskCol.TaskID, req.TaskID). + OrderDesc(entity.ModelGatewayTaskCol.CreatedAt) + if pageNum > 0 && pageSize > 0 { + model = model.Page(pageNum, pageSize) + } + r, totalInt, err := model.AllAndCount(false) + if err != nil { + return nil, 0, err + } + total = gconv.Int64(totalInt) + err = r.Structs(&list) + return +} + +// Delete 删除(软删,按ID) +func (d *modelGatewayTaskDao) Delete(ctx context.Context, req *entity.ModelGatewayTask) (rows int64, err error) { + r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask). + Where(entity.ModelGatewayTaskCol.Id, req.Id). + Delete() + if err != nil { + return + } + return r.RowsAffected() +} + +// ListByTaskIDs 批量查询 +func (d *modelGatewayTaskDao) ListByTaskIDs(ctx context.Context, taskIDs []string) (list []*entity.ModelGatewayTask, err error) { + if len(taskIDs) == 0 { + return nil, nil + } + r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask). + WhereIn(entity.ModelGatewayTaskCol.TaskID, taskIDs). + All() + if err != nil { + return nil, err + } + err = r.Structs(&list) + return +} + +// MarkDownloadedByID 标记已下载 +func (d *modelGatewayTaskDao) MarkDownloadedByID(ctx context.Context, id int64) error { + _, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask). + Where(entity.ModelGatewayTaskCol.Id, id). + Where(entity.ModelGatewayTaskCol.State, 2). + Data(map[string]any{entity.ModelGatewayTaskCol.State: 4}). + Update() + return err +} + +// GetPendingAsyncTasks 获取进行中的异步任务 +func (d *modelGatewayTaskDao) GetPendingAsyncTasks(ctx context.Context, limit int) ([]*entity.ModelGatewayTask, error) { + var tasks []*entity.ModelGatewayTask + err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask). + Where(entity.ModelGatewayTaskCol.State, 1). + Limit(limit). + Scan(&tasks) + return tasks, err +} + +// ======================== 事务抢占 ======================== + +// ClaimByID 按主键抢占,返回抢占后的任务 +func (d *modelGatewayTaskDao) ClaimByID(ctx context.Context, id int64) (*entity.ModelGatewayTask, error) { + var task entity.ModelGatewayTask + err := gfdb.DB(ctx, public.DbNameModelGateway).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error { + r, err := tx.Model(public.TableNameTask). + Where(entity.ModelGatewayTaskCol.Id, id). + Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending). + Limit(1). + LockUpdate(). + One() + if err != nil { + return err + } + if r.IsEmpty() { + return fmt.Errorf("任务已被抢占或不存在: id=%d", id) + } + if err := r.Struct(&task); err != nil { + return err + } + _, err = tx.Model(public.TableNameTask). + Data(&entity.ModelGatewayTask{State: public.TaskStatusRunning}). + Where(entity.ModelGatewayTaskCol.Id, id). + OmitEmpty(). + Update() + return err + }) + if err != nil { + return nil, err + } + return &task, nil +} diff --git a/dao/op_log_dao.go b/dao/op_log_dao.go deleted file mode 100644 index ca2e651..0000000 --- a/dao/op_log_dao.go +++ /dev/null @@ -1,30 +0,0 @@ -package dao - -import ( - "context" - - "model-gateway/consts/public" - "model-gateway/model/entity" - - "gitea.redpowerfuture.com/red-future/common/db/gfdb" - "github.com/gogf/gf/v2/util/gconv" -) - -type opLogDao struct{} - -var OpLog = &opLogDao{} - -// Insert 插入 -func (d *opLogDao) Insert(ctx context.Context, req *entity.LogsModelOp) (id int64, err error) { - m := new(entity.LogsModelOp) - err = gconv.Struct(req, &m) - if err != nil { - return - } - r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameOpLog). - Insert(m) - if err != nil { - return 0, err - } - return r.LastInsertId() -} diff --git a/dao/stat_dao.go b/dao/stat_dao.go deleted file mode 100644 index 996c4f2..0000000 --- a/dao/stat_dao.go +++ /dev/null @@ -1,60 +0,0 @@ -package dao - -import ( - "context" - "fmt" - "time" - - "model-gateway/consts/public" - "model-gateway/model/entity" - - "gitea.redpowerfuture.com/red-future/common/db/gfdb" - "github.com/gogf/gf/v2/os/gtime" -) - -type statDao struct{} - -var Stat = &statDao{} - -// IncRequestCount 原子累加(支持分布式/多协程):按天+租户+创建人+模型 +1 -func (d *statDao) IncRequestCount(ctx context.Context, day time.Time, tenantId int64, creator, modelName string) error { - sql := fmt.Sprintf(` -INSERT INTO %s(day, tenant_id, creator, model_name, request_count, created_at, updated_at) -VALUES(?, ?, ?, ?, 1, NOW(), NOW()) -ON CONFLICT (day, tenant_id, creator, model_name) -DO UPDATE SET request_count = %s.request_count + 1, updated_at = NOW()`, - public.TableNameStat, public.TableNameStat, - ) - _, err := gfdb.DB(ctx, public.DbNameModelGateway).Exec(ctx, sql, gtime.New(day).Format("Y-m-d"), tenantId, creator, modelName) - return err -} - -func (d *statDao) List(ctx context.Context, pageNum, pageSize int, startDay, endDay string, tenantId *int64, creator, modelName string) (list []*entity.LogsModelStat, total int64, err error) { - m := gfdb.DB(ctx).Model(ctx, public.TableNameStat).Where("1=1") - if startDay != "" { - m = m.Where("day >= ?", startDay) - } - if endDay != "" { - m = m.Where("day <= ?", endDay) - } - if tenantId != nil { - m = m.Where("tenant_id = ?", *tenantId) - } - if creator != "" { - m = m.WhereLike("creator", "%"+creator+"%") - } - if modelName != "" { - m = m.WhereLike("model_name", "%"+modelName+"%") - } - m = m.OrderDesc("day").OrderDesc("request_count") - if pageNum > 0 && pageSize > 0 { - m = m.Page(pageNum, pageSize) - } - r, totalInt, err := m.AllAndCount(false) - if err != nil { - return nil, 0, err - } - total = int64(totalInt) - err = r.Structs(&list) - return -} diff --git a/dao/task_dao.go b/dao/task_dao.go deleted file mode 100644 index dcb7766..0000000 --- a/dao/task_dao.go +++ /dev/null @@ -1,125 +0,0 @@ -package dao - -import ( - "context" - "model-gateway/consts/public" - "model-gateway/model/entity" - - "gitea.redpowerfuture.com/red-future/common/db/gfdb" - "github.com/gogf/gf/v2/database/gdb" - "github.com/gogf/gf/v2/os/gtime" - "github.com/gogf/gf/v2/util/gconv" -) - -var Task = &taskDao{} - -type taskDao struct{} - -// Insert 插入 -func (d *taskDao) Insert(ctx context.Context, req *entity.AsynchTask) (id int64, err error) { - m := new(entity.AsynchTask) - err = gconv.Struct(req, &m) - if err != nil { - return - } - r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask). - Insert(m) - if err != nil { - return - } - return r.LastInsertId() -} - -// Update 更新 -func (d *taskDao) Update(ctx context.Context, req *entity.AsynchTask) (rows int64, err error) { - r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask). - OmitEmpty(). - Data(&req). - Where(entity.AsynchTaskCol.Id, req.Id). - Where(entity.AsynchTaskCol.TaskID, req.TaskID). - Update() - if err != nil { - return - } - return r.RowsAffected() -} - -// Get 获取 -func (d *taskDao) Get(ctx context.Context, req *entity.AsynchTask, fields ...string) (m *entity.AsynchTask, err error) { - r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask). - OmitEmpty(). - Where(entity.AsynchTaskCol.TaskID, req.TaskID). - Fields(fields).One() - if err != nil { - return - } - err = r.Struct(&m) - return -} - -// ListByTaskIDs 批量查询任务(会受 gfdb 的租户 Hook 影响,只返回当前租户数据) -func (d *taskDao) ListByTaskIDs(ctx context.Context, taskIDs []string) (m []*entity.AsynchTask, err error) { - if len(taskIDs) == 0 { - return nil, nil - } - r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask). - OmitEmpty(). - WhereIn(entity.AsynchTaskCol.TaskID, taskIDs). - All() - if err != nil { - return nil, err - } - err = r.Structs(&m) - return -} - -// MarkDownloadedByID 将成功任务标记为已下载(state=4),并写入过期时间 -func (d *taskDao) MarkDownloadedByID(ctx context.Context, id int64, expireAt *gtime.Time) error { - data := gdb.Map{ - entity.AsynchTaskCol.State: 4, - entity.AsynchTaskCol.ExpireAt: expireAt, - entity.AsynchTaskCol.Updater: "", - } - _, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask). - Where(entity.AsynchTaskCol.Id, id). - Where(entity.AsynchTaskCol.State, 2). - Data(data). - Update() - return err -} - -// List 任务分页查询(受 gfdb 租户 Hook 影响) -func (d *taskDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike, taskIDLike string, state *int) (list []*entity.AsynchTask, total int64, err error) { - m := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).Where("deleted_at IS NULL") - if modelNameLike != "" { - m = m.WhereLike(entity.AsynchTaskCol.ModelName, "%"+modelNameLike+"%") - } - if taskIDLike != "" { - m = m.WhereLike(entity.AsynchTaskCol.TaskID, "%"+taskIDLike+"%") - } - if state != nil { - m = m.Where(entity.AsynchTaskCol.State, *state) - } - m = m.OrderDesc(entity.AsynchTaskCol.CreatedAt) - if pageNum > 0 && pageSize > 0 { - m = m.Page(pageNum, pageSize) - } - r, totalInt, err := m.AllAndCount(false) - if err != nil { - return nil, 0, err - } - total = gconv.Int64(totalInt) - err = r.Structs(&list) - return -} - -// GetPendingAsyncTasks 获取进行中的异步任务 -func (d *taskDao) GetPendingAsyncTasks(ctx context.Context, limit int) ([]*entity.AsynchTask, error) { - var tasks []*entity.AsynchTask - err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask). - Where("state", 1). - Where("deleted_at IS NULL"). - Limit(limit). - Scan(&tasks) - return tasks, err -} diff --git a/dao/task_dao_bg.go b/dao/task_dao_bg.go deleted file mode 100644 index 1141ac9..0000000 --- a/dao/task_dao_bg.go +++ /dev/null @@ -1,220 +0,0 @@ -package dao - -import ( - "context" - "fmt" - "time" - - "model-gateway/consts/public" - "model-gateway/model/entity" - - "gitea.redpowerfuture.com/red-future/common/db/gfdb" - "github.com/gogf/gf/v2/database/gdb" - "github.com/gogf/gf/v2/os/gtime" -) - -// ======================== 查询辅助 ======================== - -const taskColumns = `id, tenant_id, creator, model_name, task_id, biz_name, callback_url, model_key, retry_count, input_ref, request_payload, phase, tmp_file` - -// ======================== 通用 CRUD ======================== - -// UpdateFields 更新指定字段(map 版,用于必须更新零值的场景) -func (d *taskDao) UpdateFields(ctx context.Context, id int64, data map[string]any) (int64, error) { - r, err := gfdb.DB(ctx, public.DbNameModelGateway). - Model(ctx, public.TableNameTask). - Data(data). - Where(entity.AsynchTaskCol.Id, id). - Update() - if err != nil { - return 0, err - } - return r.RowsAffected() -} - -// execUpdate 内部辅助:执行原生 UPDATE,自动补 updated_at -func execUpdate(ctx context.Context, sql string, args ...any) error { - _, err := gfdb.DB(ctx, public.DbNameModelGateway).Exec(ctx, sql, args...) - return err -} - -// ======================== 事务抢占 ======================== - -func claimTasks(ctx context.Context, where string, args ...any) ([]*entity.AsynchTask, error) { - var tasks []*entity.AsynchTask - err := gfdb.DB(ctx, public.DbNameModelGateway).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error { - sql := fmt.Sprintf( - `SELECT %s FROM %s WHERE deleted_at IS NULL AND state = 0 %s LIMIT 1 FOR UPDATE SKIP LOCKED`, - taskColumns, public.TableNameTask, where, - ) - r, err := tx.GetOne(sql, args...) - if err != nil { - return err - } - if r.IsEmpty() { - return nil - } - var task entity.AsynchTask - if err := r.Struct(&task); err != nil { - return err - } - now := time.Now() - _, err = tx.Exec( - fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask), - now, now, task.Id, - ) - if err != nil { - return err - } - tasks = []*entity.AsynchTask{&task} - return nil - }) - return tasks, err -} - -// ClaimPendingGlobal 批量抢占 pending 任务 -func (d *taskDao) ClaimPendingGlobal(ctx context.Context, batchSize int) ([]*entity.AsynchTask, error) { - if batchSize <= 0 { - batchSize = 1 - } - var tasks []*entity.AsynchTask - err := gfdb.DB(ctx, public.DbNameModelGateway).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error { - sql := fmt.Sprintf( - `SELECT %s FROM %s WHERE deleted_at IS NULL AND state = 0 ORDER BY enqueue_at ASC LIMIT %d FOR UPDATE SKIP LOCKED`, - taskColumns, public.TableNameTask, batchSize, - ) - r, err := tx.GetAll(sql) - if err != nil { - return err - } - if r.IsEmpty() { - return nil - } - if err := r.Structs(&tasks); err != nil { - return err - } - now := time.Now() - for _, t := range tasks { - _, err = tx.Exec( - fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask), - now, now, t.Id, - ) - if err != nil { - return err - } - } - return nil - }) - return tasks, err -} - -// ClaimPendingByTaskIDGlobal 按 task_id 抢占 -func (d *taskDao) ClaimPendingByTaskIDGlobal(ctx context.Context, taskID string) (*entity.AsynchTask, error) { - if taskID == "" { - return nil, nil - } - tasks, err := claimTasks(ctx, "AND task_id = ?", taskID) - if err != nil || len(tasks) == 0 { - return nil, err - } - return tasks[0], nil -} - -// ======================== 业务更新方法 ======================== - -// RollbackToPendingGlobal 回滚到 pending 状态 -func (d *taskDao) RollbackToPendingGlobal(ctx context.Context, id int64) error { - // state=0 可能被 OmitEmpty 跳过,所以用原生 SQL + 条件 state=1 防并发 - return execUpdate(ctx, - fmt.Sprintf(`UPDATE %s SET state=0, enqueue_at=NOW(), updated_at=NOW() WHERE id=? AND state=1`, public.TableNameTask), - id, - ) -} - -// IncRetryCountGlobal 重试计数 +1 -func (d *taskDao) IncRetryCountGlobal(ctx context.Context, id int64) error { - return execUpdate(ctx, - fmt.Sprintf(`UPDATE %s SET retry_count=retry_count+1, updated_at=NOW() WHERE id=?`, public.TableNameTask), - id, - ) -} - -// RequeueForRetryGlobal 重新入队 -func (d *taskDao) RequeueForRetryGlobal(ctx context.Context, id int64, enqueueAt time.Time) error { - return execUpdate(ctx, - fmt.Sprintf(`UPDATE %s SET state=0, retry_count=retry_count+1, enqueue_at=?, updated_at=NOW() WHERE id=? AND state=3 AND deleted_at IS NULL`, public.TableNameTask), - enqueueAt, id, - ) -} - -// ======================== 列表查询 ======================== - -// ListExpiredDownloadedGlobal 查询已过期下载的任务 -func (d *taskDao) ListExpiredDownloadedGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) { - return queryTasks(ctx, - fmt.Sprintf(`SELECT * FROM %s WHERE deleted_at IS NULL AND state=4 AND expire_at IS NOT NULL AND expire_at < ? LIMIT ?`, public.TableNameTask), - gtime.Now(), clampLimit(limit, 200), - ) -} - -// ListFailedRetryableGlobal 查询可重试的失败任务 -func (d *taskDao) ListFailedRetryableGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) { - return queryTasks(ctx, - fmt.Sprintf( - `SELECT t.*, m.retry_queue_max_seconds FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state=3 AND t.retry_count < m.retry_times ORDER BY t.updated_at ASC LIMIT ?`, - public.TableNameTask, public.TableNameModel, - ), - clampLimit(limit, 200), - ) -} - -// ListFailedExhaustedGlobal 查询重试次数耗尽的失败任务 -func (d *taskDao) ListFailedExhaustedGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) { - return queryTasks(ctx, - fmt.Sprintf( - `SELECT t.* FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state=3 AND t.retry_count >= m.retry_times ORDER BY t.updated_at ASC LIMIT ?`, - public.TableNameTask, public.TableNameModel, - ), - clampLimit(limit, 200), - ) -} - -// ListTimeoutTasksGlobal 查询超时任务 -func (d *taskDao) ListTimeoutTasksGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) { - return queryTasks(ctx, - fmt.Sprintf( - `SELECT t.* FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state IN (0,1) AND m.expected_seconds > 0 AND t.created_at < (NOW() - (m.expected_seconds || ' seconds')::interval) LIMIT ?`, - public.TableNameTask, public.TableNameModel, - ), - clampLimit(limit, 200), - ) -} - -// HardDeleteByIDGlobal 物理删除任务 -func (d *taskDao) HardDeleteByIDGlobal(ctx context.Context, id int64) error { - return execUpdate(ctx, - fmt.Sprintf(`DELETE FROM %s WHERE id=?`, public.TableNameTask), - id, - ) -} - -// ======================== 内部辅助 ======================== - -func queryTasks(ctx context.Context, sql string, args ...any) ([]*entity.AsynchTask, error) { - r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, sql, args...) - if err != nil { - return nil, err - } - var list []*entity.AsynchTask - if err = r.Structs(&list); err != nil { - return nil, err - } - return list, nil -} - -func clampLimit(limit, defaultVal int) int { - if limit <= 0 { - return defaultVal - } - return limit -} diff --git a/main.go b/main.go index 70eee19..d00e8fe 100644 --- a/main.go +++ b/main.go @@ -3,7 +3,6 @@ package main import ( "context" "model-gateway/model/dto" - "model-gateway/service/job" "model-gateway/service/task" "os" "os/signal" @@ -27,9 +26,9 @@ func main() { // 注册路由 http.RouteRegister([]interface{}{ - controller.Model, - controller.Task, - controller.Stat, + controller.ModelGatewayModels, + controller.ModelGatewayTask, + controller.ModelGatewayLogsStat, }) // 本地调试:可选自动触发 worker/cleaner(由配置文件控制) @@ -47,26 +46,6 @@ func main() { } func startAutoRunner(ctx context.Context) { - // cleaner - if g.Cfg().MustGet(ctx, "asynch.cleaner.enabled").Bool() { - interval := g.Cfg().MustGet(ctx, "asynch.cleaner.intervalSeconds").Int() - if interval <= 0 { - interval = 30 - } - ticker := time.NewTicker(time.Duration(interval) * time.Second) - go func() { - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - _, _ = job.Cleaner.RunOnce(ctx) - } - } - }() - } - // queryPending if g.Cfg().MustGet(ctx, "asynch.queryPending.enabled").Bool() { interval := g.Cfg().MustGet(ctx, "asynch.queryPending.intervalSeconds", 10).Int() @@ -79,7 +58,7 @@ func startAutoRunner(ctx context.Context) { case <-ctx.Done(): return case <-ticker.C: - if _, err := task.Task.QueryPendingTasks(ctx, &dto.QueryPendingTasksReq{Limit: limit}); err != nil { + if _, err := task.ModelGatewayTask.QueryPendingTasks(ctx, &dto.QueryPendingTasksReq{Limit: limit}); err != nil { g.Log().Warningf(ctx, "[auto-queryPending] run once failed: %v", err) } } diff --git a/model/dto/model_gateway_logs_stat_dto.go b/model/dto/model_gateway_logs_stat_dto.go new file mode 100644 index 0000000..2f40f36 --- /dev/null +++ b/model/dto/model_gateway_logs_stat_dto.go @@ -0,0 +1,20 @@ +package dto + +import "github.com/gogf/gf/v2/frame/g" + +// ListModelStatReq 统计列表 +type ListModelStatReq struct { + g.Meta `path:"/listModelStat" method:"get" tags:"统计" summary:"模型请求统计列表" dc:"按天统计模型请求次数,支持分页与条件筛选"` + PageNum int `p:"pageNum" json:"pageNum" dc:"页码(默认1)"` + PageSize int `p:"pageSize" json:"pageSize" dc:"每页条数(默认10)"` + StartDay string `p:"startDay" json:"startDay" dc:"开始日期(YYYY-MM-DD,可选)"` + EndDay string `p:"endDay" json:"endDay" dc:"结束日期(YYYY-MM-DD,可选)"` + TenantID *int64 `p:"tenantId" json:"tenantId" dc:"租户ID(可选)"` + Creator string `p:"creator" json:"creator" dc:"创建人(可选,模糊匹配)"` + ModelName string `p:"modelName" json:"modelName" dc:"模型名称(可选,模糊匹配)"` +} + +type ListModelStatRes struct { + List any `json:"list" dc:"列表数据"` + Total int64 `json:"total" dc:"总数"` +} diff --git a/model/dto/model_dto.go b/model/dto/model_gateway_models_dto.go similarity index 99% rename from model/dto/model_dto.go rename to model/dto/model_gateway_models_dto.go index 7696e93..384a570 100644 --- a/model/dto/model_dto.go +++ b/model/dto/model_gateway_models_dto.go @@ -103,7 +103,7 @@ type GetModelReq struct { } type GetModelRes struct { - Model *entity.AsynchModel `json:"model" dc:"模型配置详情"` + Model *entity.ModelGatewayModel `json:"model" dc:"模型配置详情"` } // ListModelReq 配置列表 diff --git a/model/dto/task_dto.go b/model/dto/model_gateway_task_dto.go similarity index 86% rename from model/dto/task_dto.go rename to model/dto/model_gateway_task_dto.go index ebc3eb8..9e827c3 100644 --- a/model/dto/task_dto.go +++ b/model/dto/model_gateway_task_dto.go @@ -1,6 +1,8 @@ package dto -import "github.com/gogf/gf/v2/frame/g" +import ( + "github.com/gogf/gf/v2/frame/g" +) // CreateTaskReq 创建异步任务 type CreateTaskReq struct { @@ -8,7 +10,6 @@ type CreateTaskReq struct { ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"模型名称"` BizName string `p:"bizName" json:"bizName" dc:"业务名称(调用方模块/系统,用于统计)"` CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址(可选,用于后续业务通知)"` - InputRef string `p:"inputRef" json:"inputRef" dc:"输入引用(如OSS/文件引用等)"` RequestPayload map[string]any `p:"requestPayload" json:"requestPayload" dc:"请求负载(透传给模型服务)"` EpicycleId int64 `json:"epicycleId" dc:"轮次ID"` BuildType int64 `json:"buildType" dc:"构建类型:1-提示词构建 2-节点构建"` @@ -68,9 +69,10 @@ type GetTaskBatchReq struct { } type GetTaskBatchItem struct { - TaskID string `json:"taskId" dc:"任务ID"` - State int `json:"state" dc:"任务状态"` - OssFile string `json:"ossFile" dc:"结果文件OSS地址"` + TaskID string `json:"taskId" dc:"任务ID"` + State int `json:"state" dc:"任务状态"` + OssFile string `json:"ossFile" dc:"结果文件OSS地址"` + TextResult map[string]any `json:"textResult" dc:"文本结果"` } type GetTaskBatchRes struct { @@ -83,8 +85,9 @@ type ListTaskReq struct { PageNum int `p:"pageNum" json:"pageNum" dc:"页码(默认1)"` PageSize int `p:"pageSize" json:"pageSize" dc:"每页条数(默认10)"` ModelName string `p:"modelName" json:"modelName" dc:"模型名称(模糊匹配)"` + BizName string `p:"bizName" json:"bizName" dc:"业务名称"` TaskID string `p:"taskId" json:"taskId" dc:"任务ID(模糊匹配)"` - State *int `p:"state" json:"state" dc:"任务状态(0/1/2/3/4,可选)"` + State int `p:"state" json:"state" dc:"任务状态(0/1/2/3/4,可选)"` } type ListTaskRes struct { @@ -102,12 +105,3 @@ type RunWorkReq struct { type RunWorkRes struct { Claimed int `json:"claimed" dc:"本次抢占并处理的任务数"` } - -// CleanWorkReq 手动触发 cleaner 执行一次(由上层定时任务调用) -type CleanWorkReq struct { - g.Meta `path:"/cleanWork" method:"post" tags:"任务管理" summary:"执行一次Cleaner" dc:"手动触发一次清理/重试(用于由上层定时任务控制)"` -} - -type CleanWorkRes struct { - Ok bool `json:"ok" dc:"是否执行成功"` -} diff --git a/model/dto/stat_dto.go b/model/dto/stat_dto.go deleted file mode 100644 index 321fa84..0000000 --- a/model/dto/stat_dto.go +++ /dev/null @@ -1,20 +0,0 @@ -package dto - -import "github.com/gogf/gf/v2/frame/g" - -// ListModelStatReq 统计列表 -type ListModelStatReq struct { - g.Meta `path:"/listModelStat" method:"get" tags:"统计" summary:"模型请求统计列表" dc:"按天统计模型请求次数,支持分页与条件筛选"` - PageNum int `p:"pageNum" json:"pageNum" dc:"页码(默认1)"` - PageSize int `p:"pageSize" json:"pageSize" dc:"每页条数(默认10)"` - StartDay string `p:"startDay" json:"startDay" dc:"开始日期(YYYY-MM-DD,可选)"` - EndDay string `p:"endDay" json:"endDay" dc:"结束日期(YYYY-MM-DD,可选)"` - TenantID *int64 `p:"tenantId" json:"tenantId" dc:"租户ID(可选)"` - Creator string `p:"creator" json:"creator" dc:"创建人(可选,模糊匹配)"` - ModelName string `p:"modelName" json:"modelName" dc:"模型名称(可选,模糊匹配)"` -} - -type ListModelStatRes struct { - List any `json:"list" dc:"列表数据"` - Total int64 `json:"total" dc:"总数"` -} diff --git a/model/entity/asynch_task.go b/model/entity/asynch_task.go deleted file mode 100644 index 47a750e..0000000 --- a/model/entity/asynch_task.go +++ /dev/null @@ -1,89 +0,0 @@ -package entity - -import ( - "gitea.redpowerfuture.com/red-future/common/beans" - "github.com/gogf/gf/v2/os/gtime" -) - -type asynchTaskCol struct { - beans.SQLBaseCol - ModelName string - TaskID string - BizName string - CallbackURL string - ModelKey string - State string - OssFile string - FileType string - FileSize string - ErrorMsg string - StartedAt string - FinishedAt string - DurationSeconds string - ExpireAt string - RetryCount string - EnqueueAt string - Phase string - TmpFile string - InputRef string - RequestPayload string - TextResult string - EpicycleId string - ExpendTokens string -} - -var AsynchTaskCol = asynchTaskCol{ - SQLBaseCol: beans.DefSQLBaseCol, - ModelName: "model_name", - TaskID: "task_id", - BizName: "biz_name", - CallbackURL: "callback_url", - ModelKey: "model_key", - State: "state", - OssFile: "oss_file", - FileType: "file_type", - FileSize: "file_size", - ErrorMsg: "error_msg", - StartedAt: "started_at", - FinishedAt: "finished_at", - DurationSeconds: "duration_seconds", - ExpireAt: "expire_at", - RetryCount: "retry_count", - EnqueueAt: "enqueue_at", - Phase: "phase", - TmpFile: "tmp_file", - InputRef: "input_ref", - RequestPayload: "request_payload", - TextResult: "text_result", - EpicycleId: "epicycle_id", - ExpendTokens: "expend_tokens", -} - -// AsynchTask 异步任务 -type AsynchTask struct { - beans.SQLBaseDO `orm:",inline"` - ModelName string `orm:"model_name" json:"modelName"` - TaskID string `orm:"task_id" json:"taskId"` - BizName string `orm:"biz_name" json:"bizName"` - CallbackURL string `orm:"callback_url" json:"callbackUrl"` - ModelKey string `orm:"model_key" json:"modelKey"` - State int `orm:"state" json:"state"` // 0排队中/1执行中/2成功/3失败/4已下载 - OssFile string `orm:"oss_file" json:"ossFile"` - FileType string `orm:"file_type" json:"fileType"` - FileSize int64 `orm:"file_size" json:"fileSize"` - ErrorMsg string `orm:"error_msg" json:"errorMsg"` - StartedAt *gtime.Time `orm:"started_at" json:"startedAt"` - FinishedAt *gtime.Time `orm:"finished_at" json:"finishedAt"` - DurationSeconds int64 `orm:"duration_seconds" json:"durationSeconds"` - ExpireAt *gtime.Time `orm:"expire_at" json:"expireAt"` // 已下载(state=4)后的过期时间 - RetryCount int `orm:"retry_count" json:"retryCount"` - EnqueueAt *gtime.Time `orm:"enqueue_at" json:"enqueueAt"` - Phase int `orm:"phase" json:"phase"` // 0模型阶段/1OSS阶段 - TmpFile string `orm:"tmp_file" json:"tmpFile"` // 临时结果文件路径 - InputRef string `orm:"input_ref" json:"inputRef"` - RequestPayload map[string]any `orm:"request_payload" json:"requestPayload"` - TextResult map[string]any `orm:"text_result" json:"text"` - EpicycleId int64 `orm:"epicycle_id" json:"epicycleId"` // 轮次ID(用于标识同一轮次的任务) - ExpendTokens int64 `orm:"expend_tokens" json:"expendTokens"` // 消耗 token 数 - RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"-"` -} diff --git a/model/entity/logs_model_op.go b/model/entity/logs_model_op.go deleted file mode 100644 index 76b7646..0000000 --- a/model/entity/logs_model_op.go +++ /dev/null @@ -1,57 +0,0 @@ -package entity - -import ( - "gitea.redpowerfuture.com/red-future/common/beans" -) - -type LogsModelPpCol struct { - beans.SQLBaseCol - IP string - UserAgent string - APIPath string - HttpMethod string - BizName string - ModelName string - TaskID string - OpType string - Success string - ErrorMsg string - CostMs string - RequestPayload string - ResponsePayload string -} - -var LogsModelOpCol = LogsModelPpCol{ - SQLBaseCol: beans.DefSQLBaseCol, - IP: "ip", - UserAgent: "user_agent", - APIPath: "api_path", - HttpMethod: "http_method", - BizName: "biz_name", - ModelName: "model_name", - TaskID: "task_id", - OpType: "op_type", - Success: "success", - ErrorMsg: "error_msg", - CostMs: "cost_ms", - RequestPayload: "request_payload", - ResponsePayload: "response_payload", -} - -// LogsModelOp 操作日志(创建任务等) -type LogsModelOp struct { - beans.SQLBaseDO `orm:",inline"` - IP string `orm:"ip" json:"ip"` - UserAgent string `orm:"user_agent" json:"userAgent"` - APIPath string `orm:"api_path" json:"apiPath"` - HttpMethod string `orm:"http_method" json:"httpMethod"` - BizName string `orm:"biz_name" json:"bizName"` - ModelName string `orm:"model_name" json:"modelName"` - TaskID string `orm:"task_id" json:"taskId"` - OpType string `orm:"op_type" json:"opType"` - Success int `orm:"success" json:"success"` - ErrorMsg string `orm:"error_msg" json:"errorMsg"` - CostMs int64 `orm:"cost_ms" json:"costMs"` - RequestPayload any `orm:"request_payload" json:"requestPayload"` - ResponsePayload any `orm:"response_payload" json:"responsePayload"` -} diff --git a/model/entity/logs_model_stat.go b/model/entity/logs_model_stat.go deleted file mode 100644 index e4583a9..0000000 --- a/model/entity/logs_model_stat.go +++ /dev/null @@ -1,38 +0,0 @@ -package entity - -import ( - "github.com/gogf/gf/v2/os/gtime" -) - -// LogsModelStatCol 字段常量 -type LogsModelStatCol struct { - Day string - TenantId string - Creator string - ModelName string - RequestCount string - CreatedAt string - UpdatedAt string -} - -var LogsModelStatCols = LogsModelStatCol{ - Day: "day", - TenantId: "tenant_id", - Creator: "creator", - ModelName: "model_name", - RequestCount: "request_count", - CreatedAt: "created_at", - UpdatedAt: "updated_at", -} - -// LogsModelStat 按天统计:某天/租户/创建人/模型的请求次数 -// 注:这里不走通用 SQLBaseDO,采用联合唯一键(day,tenant_id,creator,model_name)做 UPSERT 原子累加。 -type LogsModelStat struct { - Day *gtime.Time `orm:"day" json:"day"` // 日期(建议仅使用日期部分) - TenantId int64 `orm:"tenant_id" json:"tenantId"` // 租户ID - Creator string `orm:"creator" json:"creator"` // 创建人/操作人 - ModelName string `orm:"model_name" json:"modelName"` // 模型名称 - RequestCount int64 `orm:"request_count" json:"requestCount"` // 请求次数 - CreatedAt *gtime.Time `orm:"created_at" json:"createdAt"` // 创建时间 - UpdatedAt *gtime.Time `orm:"updated_at" json:"updatedAt"` // 更新时间 -} diff --git a/model/entity/model_gateway_logs_op.go b/model/entity/model_gateway_logs_op.go new file mode 100644 index 0000000..1cf8a2d --- /dev/null +++ b/model/entity/model_gateway_logs_op.go @@ -0,0 +1,56 @@ +package entity + +import "gitea.redpowerfuture.com/red-future/common/beans" + +// ModelGatewayLogsOpCol 字段常量 +type modelGatewayLogsOpCol struct { + beans.SQLBaseCol + IP string + UserAgent string + APIPath string + HttpMethod string + BizName string + ModelName string + TaskID string + OpType string + Success string + ErrorMsg string + CostMs string + RequestPayload string + ResponsePayload string +} + +var ModelGatewayLogsOpCol = modelGatewayLogsOpCol{ + SQLBaseCol: beans.DefSQLBaseCol, + IP: "ip", + UserAgent: "user_agent", + APIPath: "api_path", + HttpMethod: "http_method", + BizName: "biz_name", + ModelName: "model_name", + TaskID: "task_id", + OpType: "op_type", + Success: "success", + ErrorMsg: "error_msg", + CostMs: "cost_ms", + RequestPayload: "request_payload", + ResponsePayload: "response_payload", +} + +// ModelGatewayLogsOp 操作日志 +type ModelGatewayLogsOp struct { + beans.SQLBaseDO `orm:",inline"` + IP string `orm:"ip" json:"ip"` + UserAgent string `orm:"user_agent" json:"userAgent"` + APIPath string `orm:"api_path" json:"apiPath"` + HttpMethod string `orm:"http_method" json:"httpMethod"` + BizName string `orm:"biz_name" json:"bizName"` + ModelName string `orm:"model_name" json:"modelName"` + TaskID string `orm:"task_id" json:"taskId"` + OpType string `orm:"op_type" json:"opType"` + Success int `orm:"success" json:"success"` + ErrorMsg string `orm:"error_msg" json:"errorMsg"` + CostMs int64 `orm:"cost_ms" json:"costMs"` + RequestPayload *RequestPayload `orm:"request_payload" json:"requestPayload"` + ResponsePayload map[string]any `orm:"response_payload" json:"responsePayload"` +} diff --git a/model/entity/model_gateway_logs_stat.go b/model/entity/model_gateway_logs_stat.go new file mode 100644 index 0000000..1277972 --- /dev/null +++ b/model/entity/model_gateway_logs_stat.go @@ -0,0 +1,35 @@ +package entity + +import "github.com/gogf/gf/v2/os/gtime" + +// ModelGatewayLogsStatCol 字段常量 +type ModelGatewayLogsStatCol struct { + Day string + TenantId string + Creator string + ModelName string + RequestCount string + CreatedAt string + UpdatedAt string +} + +var ModelGatewayLogsStatCols = ModelGatewayLogsStatCol{ + Day: "day", + TenantId: "tenant_id", + Creator: "creator", + ModelName: "model_name", + RequestCount: "request_count", + CreatedAt: "created_at", + UpdatedAt: "updated_at", +} + +// ModelGatewayLogsStat 按天统计 +type ModelGatewayLogsStat struct { + Day *gtime.Time `orm:"day" json:"day"` + TenantId uint64 `orm:"tenant_id" json:"tenantId"` + Creator string `orm:"creator" json:"creator"` + ModelName string `orm:"model_name" json:"modelName"` + RequestCount int64 `orm:"request_count" json:"requestCount"` + CreatedAt *gtime.Time `orm:"created_at" json:"createdAt"` + UpdatedAt *gtime.Time `orm:"updated_at" json:"updatedAt"` +} diff --git a/model/entity/asynch_model.go b/model/entity/model_gateway_model.go similarity index 92% rename from model/entity/asynch_model.go rename to model/entity/model_gateway_model.go index e6e04dd..c912dd9 100644 --- a/model/entity/asynch_model.go +++ b/model/entity/model_gateway_model.go @@ -2,7 +2,7 @@ package entity import "gitea.redpowerfuture.com/red-future/common/beans" -type asynchModelCol struct { +type modelGatewayModelCol struct { beans.SQLBaseCol ModelName string ModelType string @@ -33,9 +33,10 @@ type asynchModelCol struct { FirstFrame string LastFrame string CallbackUrl string + MaxTokens string } -var AsynchModelCol = asynchModelCol{ +var ModelGatewayModelCol = modelGatewayModelCol{ SQLBaseCol: beans.DefSQLBaseCol, ModelName: "model_name", ModelType: "model_type", @@ -66,10 +67,10 @@ var AsynchModelCol = asynchModelCol{ FirstFrame: "first_frame", LastFrame: "last_frame", CallbackUrl: "callback_url", + MaxTokens: "max_tokens", } -// AsynchModel 异步模型配置 -type AsynchModel struct { +type ModelGatewayModel struct { beans.SQLBaseDO `orm:",inline"` ModelName string `orm:"model_name" json:"modelName"` ModelType int `orm:"model_type" json:"modelType"` @@ -80,7 +81,7 @@ type AsynchModel struct { RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"` ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"` ResponseBody string `orm:"response_body" json:"responseBody"` - ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"` + ResponseTokenField string `orm:"response_token_field" json:"tokenField"` RequiredFields []string `orm:"required_fields" json:"requiredFields"` IsPrivate *int `orm:"is_private" json:"isPrivate"` IsChatModel *int `orm:"is_chat_model" json:"isChatModel"` @@ -91,7 +92,7 @@ type AsynchModel struct { TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"` RetryTimes int `orm:"retry_times" json:"retryTimes"` AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"` - IsOwner *int `json:"isOwner" orm:"is_owner"` + IsOwner *int `orm:"is_owner" json:"isOwner"` OperatorName string `orm:"operator_name" json:"operatorName"` TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"` ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"` @@ -100,4 +101,5 @@ type AsynchModel struct { FirstFrame string `orm:"first_frame" json:"firstFrame"` LastFrame string `orm:"last_frame" json:"lastFrame"` CallbackUrl string `orm:"callback_url" json:"callbackUrl"` + MaxTokens int `orm:"max_tokens" json:"maxTokens"` } diff --git a/model/entity/model_gateway_task.go b/model/entity/model_gateway_task.go new file mode 100644 index 0000000..baaa13e --- /dev/null +++ b/model/entity/model_gateway_task.go @@ -0,0 +1,76 @@ +package entity + +import ( + "gitea.redpowerfuture.com/red-future/common/beans" +) + +type modelGatewayTaskCol struct { + beans.SQLBaseCol + ModelName string + TaskID string + BizName string + CallbackURL string + State string + Phase string + ErrorMsg string + ResultFile string + TextResult string + ExpendTokens string + DurationSeconds string + RetryCount string + TmpFile string + RequestPayload string + EpicycleId string +} + +var ModelGatewayTaskCol = modelGatewayTaskCol{ + SQLBaseCol: beans.DefSQLBaseCol, + ModelName: "model_name", + TaskID: "task_id", + BizName: "biz_name", + CallbackURL: "callback_url", + State: "state", + Phase: "phase", + ErrorMsg: "error_msg", + ResultFile: "result_file", + TextResult: "text_result", + ExpendTokens: "expend_tokens", + DurationSeconds: "duration_seconds", + RetryCount: "retry_count", + TmpFile: "tmp_file", + RequestPayload: "request_payload", + EpicycleId: "epicycle_id", +} + +// ModelGatewayTask 模型网关任务 +type ModelGatewayTask struct { + beans.SQLBaseDO `orm:",inline"` + ModelName string `orm:"model_name" json:"modelName"` + TaskID string `orm:"task_id" json:"taskId"` + BizName string `orm:"biz_name" json:"bizName"` + CallbackURL string `orm:"callback_url" json:"callbackUrl"` + State int `orm:"state" json:"state"` + Phase int `orm:"phase" json:"phase"` + ErrorMsg string `orm:"error_msg" json:"errorMsg"` + ResultFile *ResultFile `orm:"result_file" json:"resultFile"` + TextResult map[string]any `orm:"text_result" json:"text"` + ExpendTokens int64 `orm:"expend_tokens" json:"expendTokens"` + DurationSeconds int64 `orm:"duration_seconds" json:"durationSeconds"` + RetryCount int `orm:"retry_count" json:"retryCount"` + TmpFile string `orm:"tmp_file" json:"tmpFile"` + RequestPayload *RequestPayload `orm:"request_payload" json:"requestPayload"` + EpicycleId int64 `orm:"epicycle_id" json:"epicycleId"` +} + +// ResultFile OSS 结果文件 +type ResultFile struct { + OssFile string `json:"ossFile"` + FileType string `json:"fileType"` + FileSize int64 `json:"fileSize"` +} + +// RequestPayload 请求参数结构体 +type RequestPayload struct { + Headers map[string]string `json:"headers"` + Body map[string]any `json:"body"` +} diff --git a/service/gateway/gateway_http_service.go b/service/gateway/gateway_http_service.go index 193070f..ac5af58 100644 --- a/service/gateway/gateway_http_service.go +++ b/service/gateway/gateway_http_service.go @@ -77,14 +77,14 @@ type CallbackPayload struct { } // TriggerCallback 任务的回调 -func TriggerCallback(ctx context.Context, t *entity.AsynchTask) { +func TriggerCallback(ctx context.Context, t *entity.ModelGatewayTask) { headers := util.ForwardHeaders(ctx) var resp struct{} payload := CallbackPayload{ TaskId: t.TaskID, State: t.State, - OssFile: t.OssFile, - FileType: t.FileType, + OssFile: t.ResultFile.OssFile, + FileType: t.ResultFile.FileType, Messages: t.TextResult, ErrorMsg: t.ErrorMsg, } @@ -111,7 +111,7 @@ type PromptsCallbackPayload struct { } // TriggerPromptsCallback 任务成功后的提示词回调 -func TriggerPromptsCallback(ctx context.Context, t *entity.AsynchTask, epicycleId int64) { +func TriggerPromptsCallback(ctx context.Context, t *entity.ModelGatewayTask, epicycleId int64) { callbackURL := "prompts-core/session/callback" headers := util.ForwardHeaders(ctx) var resp struct{} diff --git a/service/job/cleaner.go b/service/job/cleaner.go deleted file mode 100644 index d0117c4..0000000 --- a/service/job/cleaner.go +++ /dev/null @@ -1,102 +0,0 @@ -package job - -import ( - "context" - "model-gateway/model/dto" - "model-gateway/service/queue" - "os" - "time" - - "model-gateway/dao" - - "github.com/gogf/gf/v2/frame/g" -) - -var Cleaner = &cleaner{} - -type cleaner struct{} - -// RunOnce 由上层定时任务触发:执行一次清理/重试 -func (c *cleaner) RunOnce(ctx context.Context) (res *dto.CleanWorkRes, err error) { - // 1) 清理已下载(state=4)且过期的任务(硬删除 + OSS) - expired, err := dao.Task.ListExpiredDownloadedGlobal(ctx, 200) - if err != nil { - g.Log().Errorf(ctx, "[清理] 查询已下载过期任务失败: %v", err) - } else { - for _, t := range expired { - _ = os.Remove(t.TmpFile) - _ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id) - } - g.Log().Infof(ctx, "[清理] 已下载过期任务清理完成, count=%d", len(expired)) - } - - // 2) 超时任务标失败 - list, err := dao.Task.ListTimeoutTasksGlobal(ctx, 200) - if err != nil { - g.Log().Errorf(ctx, "[清理] 查询超时任务失败: %v", err) - } else { - for _, t := range list { - t.ErrorMsg = "任务超时自动失败" - _, err = dao.Task.Update(ctx, t) - if err != nil { - g.Log().Errorf(ctx, "[清理] 标记任务失败: %v", err) - } - queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) - } - g.Log().Infof(ctx, "[清理] 超时任务处理完成, count=%d", len(list)) - } - - // 3) 失败(state=3)的任务按模型配置 retry_times 重新入队(放到队尾) - retryable, err := dao.Task.ListFailedRetryableGlobal(ctx, 200) - if err != nil { - g.Log().Errorf(ctx, "[清理] 查询可重试任务失败: %v", err) - } else { - for _, t := range retryable { - // 失败任务重新入队(state=3 -> 0)前,先严格占用 queue_limit slot;占用失败则留在失败态,下一轮再尝试 - // 获取模型配置以得到 queue_limit / expected_seconds - m, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName) - if err != nil || m == nil { - continue - } - limit := queue.GetRuntimeQueueLimit(ctx, t.ModelName, m.MaxConcurrency*2) - if limit > 0 { - ok, _ := queue.AcquireQueueSlot(ctx, t.ModelName, t.TaskID, limit, m.TimeoutSeconds) - if !ok { - continue - } - } - // retry_queue_max_seconds 控制失败重试的排队策略: - // - =0:失败重试插队到队首 - // - >0:当任务从创建到现在的排队时长 >= maxSeconds,则插队到队首;否则仍放到队尾 - now := time.Now() - enqueueAt := now - maxSeconds := t.RetryQueueMaxSeconds - if maxSeconds == 0 { - enqueueAt = now.Add(-100 * 365 * 24 * time.Hour) - } else if maxSeconds > 0 && t.CreatedAt != nil { - if now.Sub(t.CreatedAt.Time) >= time.Duration(maxSeconds)*time.Second { - enqueueAt = now.Add(-100 * 365 * 24 * time.Hour) - } - } - _ = dao.Task.RequeueForRetryGlobal(ctx, t.Id, enqueueAt) - } - g.Log().Infof(ctx, "[清理] 可重试任务重新入队完成, count=%d", len(retryable)) - } - - // 4) 超过重试次数仍失败(state=3)的任务:硬删除 - exhausted, err := dao.Task.ListFailedExhaustedGlobal(ctx, 200) - if err != nil { - g.Log().Errorf(ctx, "[清理] 查询重试耗尽任务失败: %v", err) - } else { - for _, t := range exhausted { - _ = os.Remove(t.TmpFile) - // 重试耗尽硬删除:释放闸门占位(兜底,若此前已释放则幂等) - queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) - _ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id) - } - g.Log().Infof(ctx, "[清理] 重试耗尽任务清理完成, count=%d", len(exhausted)) - } - return &dto.CleanWorkRes{ - Ok: true, - }, nil -} diff --git a/service/model/model_service.go b/service/model/model_service.go index 0cce202..b28b9ff 100644 --- a/service/model/model_service.go +++ b/service/model/model_service.go @@ -18,7 +18,7 @@ import ( "github.com/gogf/gf/v2/util/gconv" ) -var Model = &modelService{} +var ModelGatewayModels = &modelService{} type modelService struct{} @@ -37,7 +37,7 @@ func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (*dt } // 3)入库 - id, err := dao.Model.Insert(ctx, util.ConvertTo[entity.AsynchModel](req)) + id, err := dao.ModelGatewayModels.Insert(ctx, util.ConvertTo[entity.ModelGatewayModel](req)) if err != nil { return nil, err } @@ -56,27 +56,27 @@ func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) erro req.IsOwner = gconv.PtrInt(1) if isAdmin, _ := gateway.IsSuperAdmin(ctx); isAdmin { req.IsOwner = gconv.PtrInt(0) - _, err := dao.Model.Update(ctx, util.ConvertTo[entity.AsynchModel](req)) + _, err := dao.ModelGatewayModels.Update(ctx, util.ConvertTo[entity.ModelGatewayModel](req)) return err } // 3)跨租户判断:超管的模型不允许直接修改,走插入新记录 - model, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{ + model, err := dao.ModelGatewayModels.GetByAcrossTenant(ctx, &entity.ModelGatewayModel{ SQLBaseDO: beans.SQLBaseDO{Id: req.ID}, }) if err != nil { return err } if model.TenantId == 1 { - _, err = dao.Model.Insert(ctx, util.ConvertTo[entity.AsynchModel](req)) + _, err = dao.ModelGatewayModels.Insert(ctx, util.ConvertTo[entity.ModelGatewayModel](req)) return err } - _, err = dao.Model.Update(ctx, util.ConvertTo[entity.AsynchModel](req)) + _, err = dao.ModelGatewayModels.Update(ctx, util.ConvertTo[entity.ModelGatewayModel](req)) return err } // Delete 删除模型 func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) error { - _, err := dao.Model.Delete(ctx, &entity.AsynchModel{ + _, err := dao.ModelGatewayModels.Delete(ctx, &entity.ModelGatewayModel{ SQLBaseDO: beans.SQLBaseDO{Id: req.ID}, }) return err @@ -91,7 +91,7 @@ func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetM if g.IsEmpty(req.ID) { req.Creator = user.UserName } - model, err := dao.Model.Get(ctx, &entity.AsynchModel{ + model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{ SQLBaseDO: beans.SQLBaseDO{ Id: req.ID, Creator: user.UserName, @@ -123,7 +123,7 @@ func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (*dto.Li req.Creator = user.UserName // 3)查询 - models, total, err := dao.Model.GetByCreatorAndPlatform(ctx, req) + models, total, err := dao.ModelGatewayModels.GetByCreatorAndPlatform(ctx, req) if err != nil { return nil, err } @@ -134,7 +134,7 @@ func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (*dto.Li // UpdateChatModel 设置会话模型 func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error { // 1)校验新模型存在 - newModel, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{ + newModel, err := dao.ModelGatewayModels.GetByAcrossTenant(ctx, &entity.ModelGatewayModel{ SQLBaseDO: beans.SQLBaseDO{Id: req.Id}, }) if err != nil || newModel == nil { @@ -146,7 +146,7 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM if err != nil { return err } - currentModel, err := dao.Model.Get(ctx, &entity.AsynchModel{ + currentModel, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{ SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName}, IsChatModel: gconv.PtrInt(1), }) @@ -161,7 +161,7 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM return errors.New("当前模型为非推理模型,不能设置为会话模型") } if currentModel.Id != req.Id { - _, err = dao.Model.Update(ctx, &entity.AsynchModel{ + _, err = dao.ModelGatewayModels.Update(ctx, &entity.ModelGatewayModel{ SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id}, IsChatModel: gconv.PtrInt(0), }) @@ -171,7 +171,7 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM } } - _, err = dao.Model.Update(ctx, &entity.AsynchModel{ + _, err = dao.ModelGatewayModels.Update(ctx, &entity.ModelGatewayModel{ SQLBaseDO: beans.SQLBaseDO{Id: req.Id}, IsChatModel: gconv.PtrInt(1), }) @@ -185,7 +185,7 @@ func (s *modelService) GetIsChatModel(ctx context.Context) (*dto.GetIsChatModelR if err != nil { return nil, err } - model, err := dao.Model.Get(ctx, &entity.AsynchModel{ + model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{ SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName}, IsChatModel: gconv.PtrInt(1), }) @@ -203,14 +203,14 @@ func (s *modelService) clearUserChatModel(ctx context.Context) error { if err != nil { return err } - model, err := dao.Model.Get(ctx, &entity.AsynchModel{ + model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{ SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName}, IsChatModel: gconv.PtrInt(1), }) if err != nil || model == nil { return nil } - _, err = dao.Model.Update(ctx, &entity.AsynchModel{ + _, err = dao.ModelGatewayModels.Update(ctx, &entity.ModelGatewayModel{ SQLBaseDO: beans.SQLBaseDO{Id: model.Id}, IsChatModel: gconv.PtrInt(0), }) @@ -223,7 +223,7 @@ func (s *modelService) checkChatModelUnique(ctx context.Context) error { if err != nil { return err } - model, err := dao.Model.Get(ctx, &entity.AsynchModel{ + model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{ SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName}, IsChatModel: gconv.PtrInt(1), }) diff --git a/service/queue/auto_tune.go b/service/queue/auto_tune.go index 5a1df52..43f4dc1 100644 --- a/service/queue/auto_tune.go +++ b/service/queue/auto_tune.go @@ -43,14 +43,14 @@ func AutoTune(ctx context.Context, req *dto.AutoTuneReq) (res *dto.AutoTuneRes, req.WindowSeconds = 3600 // 默认1小时 } // 1) 读取模型配置(cap),按 model_name 聚合去重(如果表里有多租户重复数据,取较大上限) - var modelRows []*entity.AsynchModel + var modelRows []*entity.ModelGatewayModel if err := gfdb.DB(ctx).Model(ctx, public.TableNameModel). Where("deleted_at IS NULL"). - Where(entity.AsynchModelCol.Enabled, 1). + Where(entity.ModelGatewayModelCol.Enabled, 1). Scan(&modelRows); err != nil { return nil, err } - modelMap := make(map[string]*entity.AsynchModel) + modelMap := make(map[string]*entity.ModelGatewayModel) for _, m := range modelRows { if m == nil || m.ModelName == "" { continue diff --git a/service/stat/stat_service.go b/service/stat/stat_service.go index 16b4446..f5ce942 100644 --- a/service/stat/stat_service.go +++ b/service/stat/stat_service.go @@ -2,36 +2,31 @@ package stat import ( "context" + "model-gateway/model/entity" "model-gateway/dao" "model-gateway/model/dto" ) -type statService struct{} +var ModelGatewayLogsStat = &logsStatService{} -var Stat = &statService{} +type logsStatService struct{} -func (s *statService) List(ctx context.Context, req *dto.ListModelStatReq) (res *dto.ListModelStatRes, err error) { - pageNum, pageSize := 1, 10 - if req != nil { - if req.PageNum > 0 { - pageNum = req.PageNum - } - if req.PageSize > 0 { - pageSize = req.PageSize - } +func (s *logsStatService) List(ctx context.Context, req *dto.ListModelStatReq) (*dto.ListModelStatRes, error) { + if req == nil { + req = &dto.ListModelStatReq{} } - startDay, endDay := "", "" - var tenantID *int64 - creator, modelName := "", "" - if req != nil { - startDay = req.StartDay - endDay = req.EndDay - tenantID = req.TenantID - creator = req.Creator - modelName = req.ModelName + if req.PageNum <= 0 { + req.PageNum = 1 } - list, total, err := dao.Stat.List(ctx, pageNum, pageSize, startDay, endDay, tenantID, creator, modelName) + if req.PageSize <= 0 { + req.PageSize = 10 + } + + list, total, err := dao.ModelGatewayLogsStat.List(ctx, req.PageNum, req.PageSize, &entity.ModelGatewayLogsStat{ + Creator: req.Creator, + ModelName: req.ModelName, + }) if err != nil { return nil, err } diff --git a/service/task/task_service.go b/service/task/task_service.go index 3616dc1..70d5a01 100644 --- a/service/task/task_service.go +++ b/service/task/task_service.go @@ -17,25 +17,27 @@ import ( "gitea.redpowerfuture.com/red-future/common/utils" "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/frame/g" - "github.com/gogf/gf/v2/os/gtime" "github.com/gogf/gf/v2/util/gconv" "github.com/google/uuid" ) -var Task = &taskService{} +var ModelGatewayTask = &taskService{} type taskService struct{} // Create 创建任务 func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) { - startAt := time.Now() - taskID := uuid.NewString() + var ( + startAt = time.Now() + taskID = uuid.NewString() + ) + // 1) 检查模型配置,并且获取模型 userInfo, err := utils.GetUserInfo(ctx) if err != nil { return nil, err } - model, err := dao.Model.Get(ctx, &entity.AsynchModel{ + model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{ SQLBaseDO: beans.SQLBaseDO{ TenantId: userInfo.TenantId, Creator: userInfo.UserName, @@ -66,19 +68,17 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res * // 异步调用:注入回调地址后提交,拿到 task_id 轮询 req.RequestPayload = util.InjectCallbackURL(ctx, req.RequestPayload, model.CallbackUrl) } - storedPayload := map[string]any{ - "headers": util.ParseHeadMsgHeaders(model.HeadMsg), - "body": req.RequestPayload, + requestPayload := entity.RequestPayload{ + Body: req.RequestPayload, + Headers: util.ParseHeadMsgHeaders(model.HeadMsg), } - _, err = dao.Task.Insert(ctx, &entity.AsynchTask{ + id, err := dao.ModelGatewayTask.Insert(ctx, &entity.ModelGatewayTask{ ModelName: req.ModelName, TaskID: taskID, - State: 0, + State: public.TaskStatusPending, BizName: req.BizName, CallbackURL: req.CallbackUrl, - ModelKey: model.ApiKey, - InputRef: req.InputRef, - RequestPayload: storedPayload, + RequestPayload: &requestPayload, EpicycleId: req.EpicycleId, }) if err != nil { // 入库失败:回滚闸门占位 @@ -97,7 +97,7 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res * apiPath = r.URL.Path httpMethod = r.Method } - _, _ = dao.OpLog.Insert(ctx, &entity.LogsModelOp{ + _, _ = dao.ModelGatewayLogsOp.Insert(ctx, &entity.ModelGatewayLogsOp{ IP: ip, UserAgent: ua, APIPath: apiPath, @@ -109,20 +109,17 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res * Success: 1, ErrorMsg: "", CostMs: time.Since(startAt).Milliseconds(), - RequestPayload: storedPayload, + RequestPayload: &requestPayload, ResponsePayload: gdb.Map{ "taskId": taskID, }, }) // 5) 获取任务信息 - task, err := dao.Task.ClaimPendingByTaskIDGlobal(ctx, taskID) + task, err := dao.ModelGatewayTask.ClaimByID(ctx, id) if err != nil { return nil, err } - if task == nil { - return nil, err - } // 5) 创建成功后立即异步尝试执行当前任务 go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req) @@ -130,10 +127,96 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res * return &dto.CreateTaskRes{TaskID: taskID}, nil } +// GetResult 获取任务结果 +func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) { + t, err := dao.ModelGatewayTask.Get(ctx, &entity.ModelGatewayTask{ + TaskID: taskID, + }) + if err != nil { + return nil, err + } + if t == nil { + return nil, errors.New("任务不存在") + } + return &dto.GetTaskResultRes{ + OssFile: t.ResultFile.OssFile, + State: t.State, + }, nil +} + +// GetBatch 批量查询任务;将成功(state=2)的任务更新为已下载(state=4),并写入过期时间 +func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) { + if req == nil || len(req.TaskIDs) == 0 { + return &dto.GetTaskBatchRes{List: []dto.GetTaskBatchItem{}}, nil + } + // 1) 先查当前租户下的任务列表 + list, err := dao.ModelGatewayTask.ListByTaskIDs(ctx, req.TaskIDs) + if err != nil { + return nil, err + } + + // 2) 对成功(state=2)的任务:标记为已下载(state=4) + for _, t := range list { + if t == nil { + continue + } + if t.State != public.BuildTypeNode { + continue + } + _ = dao.ModelGatewayTask.MarkDownloadedByID(ctx, t.Id) + + // 为了本次返回一致性,内存里也更新 + t.State = public.TaskStatusDownloaded + } + + // 3) 组装返回 + items := make([]dto.GetTaskBatchItem, 0, len(list)) + for _, t := range list { + if t == nil { + continue + } + items = append(items, dto.GetTaskBatchItem{ + TaskID: t.TaskID, + State: t.State, + OssFile: t.ResultFile.OssFile, + TextResult: t.TextResult, + }) + } + return &dto.GetTaskBatchRes{List: items}, nil +} + +// List 获取任务列表 +func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (*dto.ListTaskRes, error) { + if req.PageNum <= 0 { + req.PageNum = 1 + } + if req.PageSize <= 0 { + req.PageSize = 10 + } + user, err := utils.GetUserInfo(ctx) + if err != nil { + return nil, err + } + list, total, err := dao.ModelGatewayTask.List(ctx, req.PageNum, req.PageSize, &entity.ModelGatewayTask{ + SQLBaseDO: beans.SQLBaseDO{ + Creator: user.UserName, + }, + ModelName: req.ModelName, + BizName: req.BizName, + State: req.State, + TaskID: req.TaskID, + }) + if err != nil { + return nil, err + } + return &dto.ListTaskRes{List: list, Total: total}, nil +} + +// ModelTaskCallback 模型异步任务的回调通知 func (s *taskService) ModelTaskCallback(ctx context.Context, req *dto.ModelTaskCallbackReq) (*dto.ModelTaskCallbackRes, error) { g.Log().Infof(ctx, "[模型回调] 收到通知 taskID=%s status=%s", req.TaskID, req.Status) // 1. 查本地任务 - task, err := dao.Task.Get(ctx, &entity.AsynchTask{ + task, err := dao.ModelGatewayTask.Get(ctx, &entity.ModelGatewayTask{ TaskID: req.TaskID, }) if err != nil || task == nil { @@ -167,7 +250,7 @@ func (s *taskService) QueryPendingTasks(ctx context.Context, req *dto.QueryPendi } // 1. 查 state=1(执行中)的异步任务 - tasks, err := dao.Task.GetPendingAsyncTasks(ctx, limit) + tasks, err := dao.ModelGatewayTask.GetPendingAsyncTasks(ctx, limit) if err != nil { return nil, err } @@ -176,7 +259,7 @@ func (s *taskService) QueryPendingTasks(ctx context.Context, req *dto.QueryPendi var results []dto.QueryTaskItem for _, t := range tasks { // 拿到模型配置 - model, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName) + model, err := dao.ModelGatewayModels.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName) if err != nil || model == nil || model.QueryConfig == nil { continue } @@ -206,100 +289,3 @@ func (s *taskService) QueryPendingTasks(ctx context.Context, req *dto.QueryPendi Results: results, }, nil } - -// GetResult 获取任务结果 -func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) { - t, err := dao.Task.Get(ctx, &entity.AsynchTask{ - TaskID: taskID, - }) - if err != nil { - return nil, err - } - if t == nil { - return nil, errors.New("任务不存在") - } - return &dto.GetTaskResultRes{ - OssFile: t.OssFile, - State: t.State, - }, nil -} - -// GetBatch 批量查询任务;将成功(state=2)的任务更新为已下载(state=4),并写入过期时间 -func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) { - if req == nil || len(req.TaskIDs) == 0 { - return &dto.GetTaskBatchRes{List: []dto.GetTaskBatchItem{}}, nil - } - // 1) 先查当前租户下的任务列表 - list, err := dao.Task.ListByTaskIDs(ctx, req.TaskIDs) - if err != nil { - return nil, err - } - - // 2) 对成功(state=2)的任务:标记为已下载(state=4)并写入 expire_at - now := time.Now() - for _, t := range list { - if t == nil { - continue - } - if t.State != 2 { - continue - } - // 按模型配置决定保留时间 - m, err := dao.Model.Get(ctx, &entity.AsynchModel{ - ModelName: t.ModelName, - }) - if err != nil { - return nil, err - } - retainSeconds := 86400 - if m != nil && m.AutoCleanSeconds > 0 { - retainSeconds = m.AutoCleanSeconds - } - expireAt := gtime.New(now.Add(time.Duration(retainSeconds) * time.Second)) - _ = dao.Task.MarkDownloadedByID(ctx, t.Id, expireAt) - - // 为了本次返回一致性,内存里也更新 - t.State = 4 - t.ExpireAt = expireAt - } - - // 3) 组装返回 - items := make([]dto.GetTaskBatchItem, 0, len(list)) - for _, t := range list { - if t == nil { - continue - } - items = append(items, dto.GetTaskBatchItem{ - TaskID: t.TaskID, - State: t.State, - OssFile: t.OssFile, - }) - } - return &dto.GetTaskBatchRes{List: items}, nil -} - -// List 获取任务列表 -func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) { - pageNum, pageSize := 1, 10 - if req != nil { - if req.PageNum > 0 { - pageNum = req.PageNum - } - if req.PageSize > 0 { - pageSize = req.PageSize - } - } - modelName := "" - taskID := "" - var state *int - if req != nil { - modelName = req.ModelName - taskID = req.TaskID - state = req.State - } - list, total, err := dao.Task.List(ctx, pageNum, pageSize, modelName, taskID, state) - if err != nil { - return nil, err - } - return &dto.ListTaskRes{List: list, Total: total}, nil -} diff --git a/service/task/worker.go b/service/task/worker.go index e3acd91..d8d0cab 100644 --- a/service/task/worker.go +++ b/service/task/worker.go @@ -24,7 +24,6 @@ import ( "gitea.redpowerfuture.com/red-future/common/beans" "github.com/gogf/gf/v2/encoding/gjson" "github.com/gogf/gf/v2/frame/g" - "github.com/gogf/gf/v2/os/gtime" "github.com/gogf/gf/v2/util/gconv" ) @@ -34,11 +33,13 @@ type asyncWorker struct { } // handleOne 执行一次完整的任务 -func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, req *dto.CreateTaskReq) { - body := util.GetModelBody(task.RequestPayload) // 核心请求参数 - maxRetry := model.RetryTimes // 重试次数 - startTime := time.Now() - +func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq) { + var ( + body = task.RequestPayload.Body // 核心请求参数 + maxRetry = model.RetryTimes // 重试次数 + startTime = time.Now() + modelMessages = map[string]any{} + ) g.Log().Infof(ctx, "[执行任务][开始] taskId=%s model=%s", task.TaskID, task.ModelName) // 1) 分布式并发控制 @@ -51,8 +52,13 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo return } if !acquired { + _, _ = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{ + SQLBaseDO: beans.SQLBaseDO{ + Id: task.Id, + }, + State: public.TaskStatusPending, + }) g.Log().Infof(ctx, "[执行任务][排队] 并发已满,放回队列 taskId=%s", task.TaskID) - _ = w.rollbackToPending(ctx, task.Id) return } defer func() { _ = queue.ReleaseSemaphore(ctx, semKey) }() @@ -65,24 +71,24 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo w.failTask(ctx, task, startTime, err.Error()) return } - body, err = util.ParseStreamResponse(rawBytes, model.StreamConfig) + modelMessages, err = util.ParseStreamResponse(rawBytes, model.StreamConfig) if err != nil { w.failTask(ctx, task, startTime, err.Error()) return } case model.CallMode != nil && *model.CallMode == public.CallModeAsync: - body, err = w.callModel(ctx, task, model, body) + modelMessages, err = w.callModel(ctx, task, model, body) if err != nil { w.failTask(ctx, task, startTime, err.Error()) return } - body, err = util.PullTaskResult(ctx, body, model.QueryConfig, model.HeadMsg) + modelMessages, err = util.PullTaskResult(ctx, modelMessages, model.QueryConfig, model.HeadMsg) if err != nil { w.failTask(ctx, task, startTime, err.Error()) return } default: - body, err = w.callModel(ctx, task, model, body) + modelMessages, err = w.callModel(ctx, task, model, body) if err != nil { w.failTask(ctx, task, startTime, err.Error()) return @@ -90,20 +96,20 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo } // 3) 保存临时文件 - tmpPath, err := util.SaveTempFileByType(task.TaskID, body, task.TmpFile) + tmpPath, err := util.SaveTempFileByType(task.TaskID, modelMessages, task.TmpFile) if err == nil && tmpPath != "" { task.TmpFile = tmpPath task.Phase = 1 - _, err = dao.Task.Update(ctx, task) + _, err = dao.ModelGatewayTask.Update(ctx, task) if err != nil { g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err) } } // 4) 解析校验 + 响应映射(可重试,失败重新调模型) - body, err = w.parseAndRetry(ctx, body, task, model, req, maxRetry, startTime) + modelMessages, err = w.parseAndRetry(ctx, modelMessages, task, model, req, maxRetry, startTime) if err != nil { - task.TextResult = body + task.TextResult = modelMessages w.failTask(ctx, task, startTime, err.Error()) return } @@ -123,9 +129,8 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo if attempt == maxRetry { task.State = 3 task.ErrorMsg = err.Error() - task.FinishedAt = gtime.Now() task.Phase = 1 - _, err = dao.Task.Update(ctx, task) + _, err = dao.ModelGatewayTask.Update(ctx, task) if err != nil { g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err) } @@ -137,12 +142,13 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo // 6) 成功回调 task.State = 2 task.DurationSeconds = int64(time.Since(startTime).Seconds()) - task.OssFile = oss.FileAddressPrefix + oss.FileURL - task.FileType = oss.FileFormat - task.TextResult = body - task.FileSize = int64(oss.FileSize) - - if _, err = dao.Task.Update(ctx, task); err != nil { + task.ResultFile = &entity.ResultFile{ + OssFile: oss.FileAddressPrefix + oss.FileURL, + FileType: oss.FileFormat, + FileSize: int64(oss.FileSize), + } + task.TextResult = modelMessages + if _, err = dao.ModelGatewayTask.Update(ctx, task); err != nil { g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err) return } @@ -161,7 +167,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo } // callModelStream 调用模型,返回原始字节(不做响应映射,用于流式输出) -func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) ([]byte, error) { +func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) { var data []byte var err error @@ -173,8 +179,7 @@ func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.AsynchTa } if data == nil { - _ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName) - data, err = InvokeModel(ctx, model, body, task.ModelKey) + data, err = InvokeModel(ctx, model, body) if err != nil { return nil, err } @@ -182,7 +187,7 @@ func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.AsynchTa if tmpErr == nil && tmpPath != "" { task.TmpFile = tmpPath task.Phase = 1 - _, err = dao.Task.Update(ctx, task) + _, err = dao.ModelGatewayTask.Update(ctx, task) if err != nil { g.Log().Errorf(ctx, "[执行任务][失败] 临时文件保存失败 taskId=%s err=%v", task.TaskID, tmpErr) } @@ -201,7 +206,7 @@ type asyncResult struct { // asyncTaskChan 全局异步任务等待通道 var asyncTaskChan = sync.Map{} // taskID → chan asyncResult -func (w *asyncWorker) callModelAsync(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) (map[string]any, error) { +func (w *asyncWorker) callModelAsync(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, body map[string]any) (map[string]any, error) { // 1. 提交异步任务 body, err := w.callModel(ctx, task, model, body) if err != nil { @@ -246,7 +251,7 @@ func NotifyAsyncResult(taskID string, result map[string]any, err error) { // callModel 调用模型 + 检测文件类型 + 保存临时文件 // 返回: 解析后的响应体, error -func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) (map[string]any, error) { +func (w *asyncWorker) callModel(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, body map[string]any) (map[string]any, error) { var data []byte var err error @@ -261,8 +266,7 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, mo // 2) 没有可用数据,调用模型 if data == nil { - _ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName) - data, err = InvokeModel(ctx, model, body, task.ModelKey) + data, err = InvokeModel(ctx, model, body) if err != nil { return nil, err } @@ -273,7 +277,7 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, mo if tmpErr == nil && tmpPath != "" { task.TmpFile = tmpPath task.Phase = 1 - _, err = dao.Task.Update(ctx, task) + _, err = dao.ModelGatewayTask.Update(ctx, task) if err != nil { g.Log().Errorf(ctx, "[执行任务][失败] 临时文件保存失败 taskId=%s err=%v", task.TaskID, tmpErr) } @@ -297,7 +301,7 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, mo } // parseAndRetry 解析模型返回结果,并重试 -func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.AsynchTask, model *entity.AsynchModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) { +func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) { for attempt := 0; attempt <= maxRetry; attempt++ { if attempt > 0 { g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID) @@ -316,7 +320,7 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta // 2) 先存 token 到数据库,防止后续失败丢失 if _, ok := mapped[model.ResponseTokenField]; ok { task.ExpendTokens = gconv.Int64(mapped[model.ResponseTokenField]) - _, err = dao.Task.Update(ctx, &entity.AsynchTask{ + _, err = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{ SQLBaseDO: beans.SQLBaseDO{Id: task.Id}, ExpendTokens: task.ExpendTokens, }) @@ -344,9 +348,10 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta } // 4) 重新调模型(直接调,不走缓存) - _ = dao.Task.IncRetryCountGlobal(ctx, task.Id) - reqBody := util.GetModelBody(task.RequestPayload) - rawData, callErr := InvokeModel(ctx, model, reqBody, task.ModelKey) + task.RetryCount++ + _, _ = dao.ModelGatewayTask.Update(ctx, task) + rawData, callErr := InvokeModel(ctx, model, task.RequestPayload.Body) + if callErr != nil { g.Log().Warningf(ctx, "[执行任务][重调模型失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, callErr) continue @@ -354,7 +359,7 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta // 5) 解析原始响应,覆盖 body 进入下一轮 var rawResp map[string]any - if err := json.Unmarshal(rawData, &rawResp); err != nil { + if err = json.Unmarshal(rawData, &rawResp); err != nil { g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err) continue } @@ -366,18 +371,21 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta // InvokeModel 调用模型服务,返回二进制结果 // modelKey 用于覆盖/补充模型配置 head_msg(例如每次请求携带不同的 X-API-Key) -func InvokeModel(ctx context.Context, model *entity.AsynchModel, body map[string]any, modelKey string) ([]byte, error) { - // 1)请求参数映射:将标准 payload 按模型配置的 requestMapping 转为模型需要的格式 +func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) { + // 1) 记录模型调用次数 + _ = dao.ModelGatewayLogsStat.IncRequestCount(ctx, time.Now(), model.TenantId, model.Creator, model.ModelName) + + // 2)请求参数映射:将标准 payload 按模型配置的 requestMapping 转为模型需要的格式 //—— 请求映射实际处理为提示词构建请求,因为有附加字段及其他字段的拼接。这里不方便做请求映射 //mappedPayload := util.ReverseMap(model.RequestMapping, payload) - // 2)构建请求 URL 和超时 + // 3)构建请求 URL 和超时 baseURL := strings.TrimRight(model.BaseURL, "/") timeout := time.Duration(model.TimeoutSeconds) * time.Second client := &http.Client{Timeout: timeout} method := strings.ToUpper(strings.TrimSpace(model.HttpMethod)) - // 3)构建 HTTP 请求 + // 4)构建 HTTP 请求 var req *http.Request switch method { case http.MethodGet: @@ -401,31 +409,31 @@ func InvokeModel(ctx context.Context, model *entity.AsynchModel, body map[string req, err = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes)) } - // 4)注入请求头:先模型静态配置,再动态 modelKey(后者可覆盖前者) + // 5)注入请求头:先模型静态配置,再动态 modelKey(后者可覆盖前者) for hk, hv := range util.ParseHeadMsgHeaders(model.HeadMsg) { req.Header.Set(hk, hv) } - if modelKey != "" { - req.Header.Set("Authorization", "Bearer "+modelKey) + if model.ApiKey != "" { + req.Header.Set("Authorization", "Bearer "+model.ApiKey) } if method != http.MethodGet { req.Header.Set("Content-Type", "application/json") } - // 5)发送请求 + // 6)发送请求 resp, err := client.Do(req) if err != nil { return nil, err } defer resp.Body.Close() - // 6)读取响应体 + // 7)读取响应体 b, err := io.ReadAll(resp.Body) if err != nil { return nil, err } - // 7)检查 HTTP 状态码 + // 8)检查 HTTP 状态码 if resp.StatusCode < 200 || resp.StatusCode >= 300 { msg := string(b) return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg) @@ -488,7 +496,7 @@ func InvokeModel(ctx context.Context, model *entity.AsynchModel, body map[string // } // uploadOSS 从临时文件上传 OSS -func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.AsynchTask) (*gateway.UploadFileResponse, error) { +func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.ModelGatewayTask) (*gateway.UploadFileResponse, error) { data, err := os.ReadFile(t.TmpFile) if err != nil { return nil, fmt.Errorf("读取临时文件失败: %w", err) @@ -498,19 +506,14 @@ func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.AsynchTask) (*gat } // failTask 任务失败统一处理:更新数据库 + 释放排队 + 回调 -func (w *asyncWorker) failTask(ctx context.Context, t *entity.AsynchTask, startTime time.Time, errMsg string) { +func (w *asyncWorker) failTask(ctx context.Context, t *entity.ModelGatewayTask, startTime time.Time, errMsg string) { t.State = 3 t.ErrorMsg = errMsg t.DurationSeconds = int64(time.Since(startTime).Seconds()) - _, err := dao.Task.Update(ctx, t) + _, err := dao.ModelGatewayTask.Update(ctx, t) if err != nil { g.Log().Warningf(ctx, "[执行任务][更新数据库失败] taskId=%s err=%v", t.TaskID, err) } queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) go gateway.TriggerCallback(context.WithoutCancel(ctx), t) } - -// rollbackToPending 恢复任务状态为 PENDING -func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error { - return dao.Task.RollbackToPendingGlobal(ctx, id) -} diff --git a/update.sql b/update.sql index e1cdef1..f60414c 100644 --- a/update.sql +++ b/update.sql @@ -109,88 +109,61 @@ COMMENT ON COLUMN asynch_models.response_token_field IS '响应中消耗token的 -- ========================= --- 2) asynch_task +-- model_gateway_task -- ========================= -CREATE TABLE IF NOT EXISTS asynch_task ( - -- 基础字段 - id BIGINT PRIMARY KEY, -- 主键ID(非自增) - tenant_id BIGINT NOT NULL DEFAULT 0, -- 租户ID - creator VARCHAR(64) NOT NULL, -- 创建人 - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- 创建时间 - updater VARCHAR(64) NOT NULL, -- 更新人 - updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- 更新时间 - deleted_at TIMESTAMP(6), -- 删除时间(软删) - - -- 业务字段 - model_name VARCHAR(128) NOT NULL, -- 模型名称 - task_id VARCHAR(64) NOT NULL, -- 任务ID(对外返回) - biz_name VARCHAR(128) NOT NULL DEFAULT '', -- 业务名称(调用方模块/系统) - callback_url VARCHAR(512) DEFAULT '', -- 回调地址(可选,用于后续业务通知) - model_key VARCHAR(1024) DEFAULT '', -- 动态请求头(用于覆盖/补充模型配置 head_msg),如 X-API-Key:xxx - state SMALLINT NOT NULL DEFAULT 0, -- 0排队中/1执行中/2成功/3失败/4已下载 - oss_file VARCHAR(512) DEFAULT '', -- 结果文件OSS地址 - file_type VARCHAR(32) DEFAULT '', -- 文件类型(mp3/mp4/png/...) - file_size BIGINT NOT NULL DEFAULT 0, -- 文件大小(字节) - error_msg TEXT DEFAULT '', -- 错误信息 - started_at TIMESTAMP, -- 开始执行时间 - finished_at TIMESTAMP, -- 执行结束时间 - duration_seconds BIGINT NOT NULL DEFAULT 0, -- 耗时(秒):从创建到完成(成功/失败)整体耗时 - expire_at TIMESTAMP, -- state=4 后写入,用于清理 - retry_count INT NOT NULL DEFAULT 0, -- 已重试次数(不含首次) - enqueue_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- 入队时间(用于排队顺序) - phase SMALLINT NOT NULL DEFAULT 0, -- 0模型阶段/1OSS阶段 - tmp_file TEXT DEFAULT '', -- 临时结果文件路径(phase=1 时仅重试 OSS 上传) - input_ref TEXT DEFAULT '', -- 输入引用(如OSS/业务资源ID等) - request_payload JSONB, -- 请求参数(可选) - text_result TEXT DEFAULT '', -- 文本类结果(可选,支持直接回调) - epicycle_id VARCHAR(64) DEFAULT '', -- 轮次ID - expend_tokens BIGINT NOT NULL DEFAULT 0 -- 消耗 token 数 +CREATE TABLE model_gateway_task ( + id int8 PRIMARY KEY, + tenant_id int8 NOT NULL DEFAULT 0, + creator varchar(64) NOT NULL, + created_at timestamp(6) NOT NULL DEFAULT CURRENT_TIMESTAMP, + updater varchar(64) NOT NULL, + updated_at timestamp(6) NOT NULL DEFAULT CURRENT_TIMESTAMP, + deleted_at timestamp(6), + model_name varchar(128) NOT NULL, + task_id varchar(64) NOT NULL, + biz_name varchar(128) NOT NULL DEFAULT '', + callback_url varchar(512) DEFAULT '', + state int2 NOT NULL DEFAULT 0, + retry_count int4 NOT NULL DEFAULT 0, + phase int2 NOT NULL DEFAULT 0, + tmp_file text DEFAULT '', + error_msg text DEFAULT '', + result_file jsonb NOT NULL DEFAULT '{}', + request_payload jsonb NOT NULL DEFAULT '{}', + text_result jsonb NOT NULL DEFAULT '{}', + expend_tokens int8 NOT NULL DEFAULT 0, + duration_seconds int8 NOT NULL DEFAULT 0, + epicycle_id varchar(64) NOT NULL DEFAULT '' ); -CREATE UNIQUE INDEX IF NOT EXISTS uk_asynch_task_tenant_task_id ON asynch_task(tenant_id, task_id); -CREATE INDEX IF NOT EXISTS idx_asynch_task_tenant_id ON asynch_task(tenant_id); -CREATE INDEX IF NOT EXISTS idx_asynch_task_model_name ON asynch_task(model_name); -CREATE INDEX IF NOT EXISTS idx_asynch_task_biz_name ON asynch_task(biz_name); -CREATE INDEX IF NOT EXISTS idx_asynch_task_model_key ON asynch_task(model_key); -CREATE INDEX IF NOT EXISTS idx_asynch_task_state ON asynch_task(state); -CREATE INDEX IF NOT EXISTS idx_asynch_task_enqueue_at ON asynch_task(enqueue_at); -CREATE INDEX IF NOT EXISTS idx_asynch_task_updated_at ON asynch_task(updated_at); -CREATE INDEX IF NOT EXISTS idx_asynch_task_expire_at ON asynch_task(expire_at); -CREATE INDEX IF NOT EXISTS idx_asynch_task_deleted_at ON asynch_task(deleted_at); -CREATE INDEX IF NOT EXISTS idx_asynch_task_epicycle_id ON asynch_task(epicycle_id); -CREATE INDEX IF NOT EXISTS idx_asynch_task_expend_tokens ON asynch_task(expend_tokens); +CREATE UNIQUE INDEX uk_model_gateway_task_tenant_creator_task_id ON model_gateway_task (tenant_id, creator, task_id); +CREATE INDEX idx_model_gateway_task_task_id ON model_gateway_task (task_id); +CREATE INDEX idx_model_gateway_task_state ON model_gateway_task (state); +CREATE INDEX idx_model_gateway_task_deleted_at ON model_gateway_task (deleted_at); -COMMENT ON TABLE asynch_task IS '异步任务表'; -COMMENT ON COLUMN asynch_task.id IS '主键ID(非自增)'; -COMMENT ON COLUMN asynch_task.tenant_id IS '租户ID'; -COMMENT ON COLUMN asynch_task.creator IS '创建人'; -COMMENT ON COLUMN asynch_task.created_at IS '创建时间'; -COMMENT ON COLUMN asynch_task.updater IS '更新人'; -COMMENT ON COLUMN asynch_task.updated_at IS '更新时间'; -COMMENT ON COLUMN asynch_task.deleted_at IS '删除时间(软删)'; -COMMENT ON COLUMN asynch_task.model_name IS '模型名称'; -COMMENT ON COLUMN asynch_task.task_id IS '任务ID(对外返回)'; -COMMENT ON COLUMN asynch_task.biz_name IS '业务名称(调用方模块/系统)'; -COMMENT ON COLUMN asynch_task.callback_url IS '回调地址(可选,用于后续业务通知)'; -COMMENT ON COLUMN asynch_task.model_key IS '动态请求头(用于覆盖/补充模型配置 head_msg),如 X-API-Key:xxx'; -COMMENT ON COLUMN asynch_task.state IS '0排队中/1执行中/2成功/3失败/4已下载'; -COMMENT ON COLUMN asynch_task.oss_file IS '结果文件OSS地址'; -COMMENT ON COLUMN asynch_task.file_type IS '文件类型(mp3/mp4/png/...)'; -COMMENT ON COLUMN asynch_task.file_size IS '文件大小(字节)'; -COMMENT ON COLUMN asynch_task.error_msg IS '错误信息'; -COMMENT ON COLUMN asynch_task.started_at IS '开始执行时间'; -COMMENT ON COLUMN asynch_task.finished_at IS '执行结束时间'; -COMMENT ON COLUMN asynch_task.duration_seconds IS '耗时(秒):从创建到完成(成功/失败)整体耗时'; -COMMENT ON COLUMN asynch_task.expire_at IS 'state=4 后写入,用于清理'; -COMMENT ON COLUMN asynch_task.retry_count IS '已重试次数(不含首次)'; -COMMENT ON COLUMN asynch_task.enqueue_at IS '入队时间(用于排队顺序)'; -COMMENT ON COLUMN asynch_task.phase IS '执行阶段 模型阶段/1OSS阶段(模型已成功,等待上传OSS)'; -COMMENT ON COLUMN asynch_task.tmp_file IS '临时结果文件路径(phase=1 时仅重试 OSS 上传)'; -COMMENT ON COLUMN asynch_task.input_ref IS '输入引用(如OSS/业务资源ID等)'; -COMMENT ON COLUMN asynch_task.request_payload IS '请求参数(可选,JSON)'; -COMMENT ON COLUMN asynch_task.text_result IS '文本类结果(可选,支持直接回调)'; -COMMENT ON COLUMN asynch_task.epicycle_id IS '轮次ID(用于标识同一轮次的任务)'; -COMMENT ON COLUMN asynch_task.expend_tokens IS '消耗 token 数'; +COMMENT ON TABLE model_gateway_task IS '模型网关任务表'; +COMMENT ON COLUMN model_gateway_task.id IS '主键ID'; +COMMENT ON COLUMN model_gateway_task.tenant_id IS '租户ID'; +COMMENT ON COLUMN model_gateway_task.creator IS '创建人'; +COMMENT ON COLUMN model_gateway_task.created_at IS '创建时间'; +COMMENT ON COLUMN model_gateway_task.updater IS '更新人'; +COMMENT ON COLUMN model_gateway_task.updated_at IS '更新时间'; +COMMENT ON COLUMN model_gateway_task.deleted_at IS '删除时间(软删)'; +COMMENT ON COLUMN model_gateway_task.model_name IS '模型名称'; +COMMENT ON COLUMN model_gateway_task.task_id IS '任务ID(对外返回)'; +COMMENT ON COLUMN model_gateway_task.biz_name IS '业务名称(调用方模块/系统)'; +COMMENT ON COLUMN model_gateway_task.callback_url IS '回调地址'; +COMMENT ON COLUMN model_gateway_task.state IS '0排队中/1执行中/2成功/3失败/4已下载'; +COMMENT ON COLUMN model_gateway_task.retry_count IS '已重试次数'; +COMMENT ON COLUMN model_gateway_task.phase IS '执行阶段:0模型阶段/1OSS阶段'; +COMMENT ON COLUMN model_gateway_task.tmp_file IS '临时结果文件路径'; +COMMENT ON COLUMN model_gateway_task.error_msg IS '错误信息'; +COMMENT ON COLUMN model_gateway_task.result_file IS '结果文件:{oss_file, file_type, file_size}'; +COMMENT ON COLUMN model_gateway_task.request_payload IS '请求参数(JSON)'; +COMMENT ON COLUMN model_gateway_task.text_result IS '文本类结果'; +COMMENT ON COLUMN model_gateway_task.expend_tokens IS '消耗token数'; +COMMENT ON COLUMN model_gateway_task.duration_seconds IS '耗时(秒)'; +COMMENT ON COLUMN model_gateway_task.epicycle_id IS '轮次ID';