Compare commits

3 Commits

34 changed files with 1000 additions and 1472 deletions

View File

@@ -13,7 +13,6 @@ import (
"strings"
"time"
"gitea.redpowerfuture.com/red-future/common/utils"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
@@ -21,7 +20,7 @@ import (
)
// ParseAndValidate 解析并校验结果
func ParseAndValidate(raw map[string]any, model *entity.AsynchModel) (map[string]any, error) {
func ParseAndValidate(raw map[string]any, model *entity.ModelGatewayModel) (map[string]any, error) {
// 1) 解析 content 字符串为 rounds 数组
contentVal, ok := raw[model.ResponseBody]
if !ok {
@@ -94,53 +93,6 @@ func ParseStructResult(raw map[string]any, responseBody string) map[string]any {
}
}
// ValidatePromptResult checks the structural integrity of a model result.
//
// raw must contain a "rounds" entry holding a non-empty []any. When the model
// configures RequiredFields, every element of rounds that is a map[string]any
// must resolve each required field (looked up with gjson's Get) to a non-nil
// value. Elements that are not map[string]any are skipped rather than
// rejected. A nil model is treated as "no required fields configured".
//
// It returns nil when the structure is acceptable, otherwise an error naming
// the first problem found.
func ValidatePromptResult(raw map[string]any, model *entity.AsynchModel) error {
	// 1) rounds must exist, be an array, and be non-empty.
	roundsRaw, ok := raw["rounds"]
	if !ok {
		return fmt.Errorf("缺少 rounds 字段")
	}
	rounds, ok := roundsRaw.([]any)
	if !ok {
		return fmt.Errorf("rounds 不是数组")
	}
	if len(rounds) == 0 {
		return fmt.Errorf("rounds 数组为空")
	}

	// 2) No required fields configured: nothing more to enforce.
	if model == nil || len(model.RequiredFields) == 0 {
		return nil
	}

	// 3) Check every round object against the required fields.
	for i, r := range rounds {
		round, ok := r.(map[string]any)
		if !ok {
			continue
		}
		// Build the JSON view once per round instead of once per field.
		roundJSON := gjson.New(round)
		for _, field := range model.RequiredFields {
			if roundJSON.Get(field).IsNil() {
				return fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, field)
			}
		}
	}
	return nil
}
// validateRequiredFields checks that a single round object provides every
// required field.
//
// Each entry of requiredFields is looked up with gjson's Get; the first one
// that resolves to nil produces an error whose message starts with prefix.
// It returns nil when all fields are present or requiredFields is empty.
func validateRequiredFields(round map[string]any, requiredFields []string, prefix string) error {
	if len(requiredFields) == 0 {
		return nil
	}
	// Build the JSON view once; it does not change between lookups.
	roundJSON := gjson.New(round)
	for _, field := range requiredFields {
		if roundJSON.Get(field).IsNil() {
			return fmt.Errorf("%s 缺少必填字段: %s", prefix, field)
		}
	}
	return nil
}
// ParseHeadMsgHeaders 从 head_msg JSON 中提取请求头
// head_msg 格式示例:
//
@@ -198,16 +150,17 @@ func MapResponsePayload(mapping map[string]any, result map[string]any) (map[stri
return mapped, nil
}
// GetModelBody extracts the model information stored in the database record.
//
// A nil input yields nil. When v has a "body" key, its value is converted
// with gconv.Map and returned; otherwise v itself is returned unchanged.
func GetModelBody(v map[string]any) map[string]any {
	if v == nil {
		return nil
	}
	body, found := v["body"]
	if !found {
		return v
	}
	return gconv.Map(body)
}
//
//// GetModelBody 获取数据库中保存的模型信息
//func GetModelBody(v map[string]any) map[string]any {
// if v == nil {
// return nil
// }
// if p, ok := v["body"]; ok {
// return gconv.Map(p)
// }
// return v
//}
// BodyToQuery 将 body 转为 url.Values
func BodyToQuery(payload map[string]any) (url.Values, error) {
@@ -348,12 +301,3 @@ func replaceURLParams(url string, params map[string]any) string {
return s
})
}
// InjectCallbackURL writes the callback address into the request payload.
//
// Despite its name, callbackURL is used as the payload KEY under which the
// address is stored; the address itself comes from
// utils.GetCallbackURL(ctx, "/task/modelCallback"). An empty callbackURL
// leaves payload untouched. A nil payload is replaced by a fresh map so the
// write cannot panic. The (possibly newly allocated) payload is returned.
func InjectCallbackURL(ctx context.Context, payload map[string]any, callbackURL string) map[string]any {
	if callbackURL == "" {
		return payload
	}
	if payload == nil {
		// Assigning into a nil map panics; allocate one for the single entry.
		payload = make(map[string]any, 1)
	}
	payload[callbackURL] = utils.GetCallbackURL(ctx, "/task/modelCallback")
	return payload
}

View File

@@ -6,6 +6,14 @@ const (
CallModeStream = 2 // 流式调用
)
// Task lifecycle states. The numeric values are written explicitly because
// they are compared against the task table's state column.
const (
TaskStatusPending = 0 // queued
TaskStatusRunning = 1 // running
TaskStatusSuccess = 2 // succeeded
TaskStatusFailed = 3 // failed
TaskStatusDownloaded = 4 // downloaded
)
const (
BuildTypePrompt = 1 //提示词构建
BuildTypeNode = 2 //节点构建

View File

@@ -5,8 +5,8 @@ const (
)
const (
TableNameModel = "asynch_models" // 模型表
TableNameTask = "asynch_task" // 任务表
TableNameOpLog = "logs_model_op" // 操作日志表
TableNameStat = "logs_model_stat" // 按天统计表(请求次数)
TableNameModel = "model_gateway_models" // 模型表
TableNameTask = "model_gateway_task" // 任务表
TableNameOpLog = "model_gateway_logs_op" // 操作日志表
TableNameStat = "model_gateway_logs_stat" // 按天统计表
)

View File

@@ -7,12 +7,12 @@ import (
"model-gateway/model/dto"
)
type stat struct{}
// ModelGatewayLogsStat 统计控制器
var ModelGatewayLogsStat = new(stat)
// Stat 统计控制器
var Stat = new(stat)
type stat struct{}
// ListModelStat 统计列表
func (c *stat) ListModelStat(ctx context.Context, req *dto.ListModelStatReq) (res *dto.ListModelStatRes, err error) {
return statService.Stat.List(ctx, req)
return statService.ModelGatewayLogsStat.List(ctx, req)
}

View File

@@ -7,36 +7,36 @@ import (
"model-gateway/service/queue"
)
type model struct{}
// ModelGatewayModels 模型配置控制器
var ModelGatewayModels = new(model)
// Model 模型配置控制器
var Model = new(model)
type model struct{}
// CreateModel 添加配置
func (c *model) CreateModel(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
return modelService.Model.Create(ctx, req)
return modelService.ModelGatewayModels.Create(ctx, req)
}
// UpdateModel 更改配置
func (c *model) UpdateModel(ctx context.Context, req *dto.UpdateModelReq) (res *dto.UpdateModelRes, err error) {
err = modelService.Model.Update(ctx, req)
err = modelService.ModelGatewayModels.Update(ctx, req)
return
}
// DeleteModel 删除配置
func (c *model) DeleteModel(ctx context.Context, req *dto.DeleteModelReq) (res *dto.DeleteModelRes, err error) {
err = modelService.Model.Delete(ctx, req)
err = modelService.ModelGatewayModels.Delete(ctx, req)
return
}
// GetModel 获取配置详情
func (c *model) GetModel(ctx context.Context, req *dto.GetModelReq) (res *dto.GetModelRes, err error) {
return modelService.Model.Get(ctx, req)
return modelService.ModelGatewayModels.Get(ctx, req)
}
// ListModel 配置列表
func (c *model) ListModel(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) {
return modelService.Model.List(ctx, req)
return modelService.ModelGatewayModels.List(ctx, req)
}
// AutoTune 动态调参(由上层定时任务每小时触发一次)
@@ -56,11 +56,11 @@ func (c *model) ListOperator(ctx context.Context, req *dto.ListOperatorReq) (res
// UpdateChatModel 更新是否为聊天模型
func (c *model) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) (res *dto.UpdateChatModelRes, err error) {
err = modelService.Model.UpdateChatModel(ctx, req)
err = modelService.ModelGatewayModels.UpdateChatModel(ctx, req)
return
}
// GetIsChatModel 获取当前会话模型
func (c *model) GetIsChatModel(ctx context.Context, req *dto.GetIsChatModelReq) (res *dto.GetIsChatModelRes, err error) {
return modelService.Model.GetIsChatModel(ctx)
return modelService.ModelGatewayModels.GetIsChatModel(ctx)
}

View File

@@ -2,48 +2,42 @@ package controller
import (
"context"
"model-gateway/service/job"
taskService "model-gateway/service/task"
"model-gateway/model/dto"
)
type task struct{}
// ModelGatewayTask 任务控制器
var ModelGatewayTask = new(task)
// Task 任务控制器
var Task = new(task)
type task struct{}
// CreateTask 根据 modelName 创建异步任务,返回 taskId
func (c *task) CreateTask(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
return taskService.Task.Create(ctx, req)
return taskService.ModelGatewayTask.Create(ctx, req)
}
// ModelTaskCallback 接收模型异步任务的回调通知
func (c *task) ModelTaskCallback(ctx context.Context, req *dto.ModelTaskCallbackReq) (res *dto.ModelTaskCallbackRes, err error) {
return taskService.Task.ModelTaskCallback(ctx, req)
}
// QueryPendingTasks 批量轮询进行中的异步任务
func (c *task) QueryPendingTasks(ctx context.Context, req *dto.QueryPendingTasksReq) (res *dto.QueryPendingTasksRes, err error) {
return taskService.Task.QueryPendingTasks(ctx, req)
}
// GetTaskResult 获取任务结果(只返回 oss 地址 + state
// GetTaskResult 获取单条任务结果(返回 *dto.GetTaskResultRes
func (c *task) GetTaskResult(ctx context.Context, req *dto.GetTaskResultReq) (res *dto.GetTaskResultRes, err error) {
return taskService.Task.GetResult(ctx, req.TaskID)
return taskService.ModelGatewayTask.GetResult(ctx, req.TaskID)
}
// GetTaskBatch 批量查询任务(成功任务标记为已下载
// GetTaskBatch 批量查询任务(返回 *[]dto.GetTaskBatchItem
func (c *task) GetTaskBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
return taskService.Task.GetBatch(ctx, req)
return taskService.ModelGatewayTask.GetBatch(ctx, req)
}
// ListTask 任务列表分页查询
func (c *task) ListTask(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
return taskService.Task.List(ctx, req)
return taskService.ModelGatewayTask.List(ctx, req)
}
// CleanWork 手动触发一次 cleaner由上层定时任务调用
func (c *task) CleanWork(ctx context.Context, req *dto.CleanWorkReq) (res *dto.CleanWorkRes, err error) {
return job.Cleaner.RunOnce(ctx)
// ModelTaskCallback 接收模型异步任务的回调通知 —— 待调整
func (c *task) ModelTaskCallback(ctx context.Context, req *dto.ModelTaskCallbackReq) (res *dto.ModelTaskCallbackRes, err error) {
return taskService.ModelGatewayTask.ModelTaskCallback(ctx, req)
}
// QueryPendingTasks 批量轮询进行中的异步任务 —— 待调整
func (c *task) QueryPendingTasks(ctx context.Context, req *dto.QueryPendingTasksReq) (res *dto.QueryPendingTasksRes, err error) {
return taskService.ModelGatewayTask.QueryPendingTasks(ctx, req)
}

View File

@@ -0,0 +1,23 @@
package dao
import (
"context"
"model-gateway/consts/public"
"model-gateway/model/entity"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
)
// ModelGatewayLogsOp is the package-level DAO instance for the operation-log table.
var ModelGatewayLogsOp = &modelGatewayLogsOpDao{}
// modelGatewayLogsOpDao is a stateless receiver for operation-log queries.
type modelGatewayLogsOpDao struct{}
// Insert 插入操作日志
func (d *modelGatewayLogsOpDao) Insert(ctx context.Context, req *entity.ModelGatewayLogsOp) (int64, error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameOpLog).Insert(req)
if err != nil {
return 0, err
}
return r.LastInsertId()
}

View File

@@ -0,0 +1,52 @@
package dao
import (
"context"
"time"
"model-gateway/consts/public"
"model-gateway/model/entity"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/util/gconv"
)
// ModelGatewayLogsStat is the package-level DAO instance for the daily statistics table.
var ModelGatewayLogsStat = &modelGatewayLogsStatDao{}
// modelGatewayLogsStatDao is a stateless receiver for statistics queries.
type modelGatewayLogsStatDao struct{}
// IncRequestCount atomically adds 1 to the request counter identified by
// (day, tenant, creator, model), creating the row with a count of 1 when it
// does not exist yet.
//
// The upsert is issued as one parameterised statement. The previous chain
// `Data(...).OnDuplicate("request_count", "request_count+1").Insert()` did not
// increment: GoFrame treats multiple string arguments to OnDuplicate as a list
// of column names, and only applies OnDuplicate for Save(), not Insert().
//
// NOTE(review): the statement uses PostgreSQL ON CONFLICT syntax and assumes a
// unique constraint on (day, tenant_id, creator, model_name) plus created_at /
// updated_at columns, as the statistics table's earlier DAO did — confirm
// against the model_gateway_logs_stat schema.
func (d *modelGatewayLogsStatDao) IncRequestCount(ctx context.Context, day time.Time, tenantId uint64, creator, modelName string) error {
	const table = public.TableNameStat
	// The table name is a compile-time constant; all values are bound as parameters.
	sql := "INSERT INTO " + table + "(day, tenant_id, creator, model_name, request_count, created_at, updated_at) " +
		"VALUES(?, ?, ?, ?, 1, NOW(), NOW()) " +
		"ON CONFLICT (day, tenant_id, creator, model_name) " +
		"DO UPDATE SET request_count = " + table + ".request_count + 1, updated_at = NOW()"
	_, err := gfdb.DB(ctx, public.DbNameModelGateway).Exec(ctx, sql,
		gtime.New(day).Format("Y-m-d"), tenantId, creator, modelName)
	return err
}
// List returns statistics rows, newest day first and then by request count
// descending, together with the total number of matching rows.
//
// Filters come from req: Creator is matched exactly (skipped when empty via
// OmitEmpty) and ModelName is matched with LIKE '%name%'. The LIKE clause is
// only added when ModelName is non-empty; previously the pattern "%%" was
// always sent because it is never empty for OmitEmpty. Paging is applied
// only when both pageNum and pageSize are positive.
func (d *modelGatewayLogsStatDao) List(ctx context.Context, pageNum, pageSize int, req *entity.ModelGatewayLogsStat) (list []*entity.ModelGatewayLogsStat, total int64, err error) {
	query := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameStat).
		OmitEmpty().
		Where(entity.ModelGatewayLogsStatCols.Creator, req.Creator)
	if req.ModelName != "" {
		query = query.WhereLike(entity.ModelGatewayLogsStatCols.ModelName, "%"+req.ModelName+"%")
	}
	query = query.
		OrderDesc(entity.ModelGatewayLogsStatCols.Day).
		OrderDesc(entity.ModelGatewayLogsStatCols.RequestCount)
	if pageNum > 0 && pageSize > 0 {
		query = query.Page(pageNum, pageSize)
	}
	r, totalInt, err := query.AllAndCount(false)
	if err != nil {
		return nil, 0, err
	}
	total = gconv.Int64(totalInt)
	err = r.Structs(&list)
	return list, total, err
}

View File

@@ -9,71 +9,64 @@ import (
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
var Model = &modelDao{}
var ModelGatewayModels = &modelGatewayModelsDao{}
type modelDao struct{}
type modelGatewayModelsDao struct{}
// Insert 插入
func (d *modelDao) Insert(ctx context.Context, req *entity.AsynchModel) (id int64, err error) {
m := new(entity.AsynchModel)
err = gconv.Struct(req, &m)
func (d *modelGatewayModelsDao) Insert(ctx context.Context, req *entity.ModelGatewayModel) (int64, error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).Insert(req)
if err != nil {
return
}
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
Insert(m)
if err != nil {
return
return 0, err
}
return r.LastInsertId()
}
// Update 更新
func (d *modelDao) Update(ctx context.Context, req *entity.AsynchModel) (rows int64, err error) {
func (d *modelGatewayModelsDao) Update(ctx context.Context, req *entity.ModelGatewayModel) (int64, error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
OmitEmpty().
Data(&req).
Where(entity.AsynchModelCol.Id, req.Id).
Data(req).
Where(entity.ModelGatewayModelCol.Id, req.Id).
Update()
if err != nil {
return
return 0, err
}
return r.RowsAffected()
}
// Delete 删除
func (d *modelDao) Delete(ctx context.Context, req *entity.AsynchModel) (rows int64, err error) {
func (d *modelGatewayModelsDao) Delete(ctx context.Context, req *entity.ModelGatewayModel) (int64, error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
OmitEmpty().
Where(entity.AsynchModelCol.Id, req.Id).
Where(entity.ModelGatewayModelCol.Id, req.Id).
Delete()
if err != nil {
return
return 0, err
}
return r.RowsAffected()
}
// Get 获取模型
func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
func (d *modelGatewayModelsDao) Get(ctx context.Context, req *entity.ModelGatewayModel, fields ...string) (*entity.ModelGatewayModel, error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
OmitEmpty().
Where(entity.AsynchModelCol.Id, req.Id).
Where(entity.AsynchModelCol.Creator, req.Creator).
Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel).
Where(entity.AsynchModelCol.ModelName, req.ModelName).
Where(entity.ModelGatewayModelCol.Id, req.Id).
Where(entity.ModelGatewayModelCol.Creator, req.Creator).
Where(entity.ModelGatewayModelCol.ModelName, req.ModelName).
Fields(fields).One()
if err != nil {
return
return nil, err
}
var m entity.ModelGatewayModel
err = r.Struct(&m)
return
return &m, err
}
//// Get 按ID获取带租户隔离只查当前租户
//func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
//func (d *modelGatewayModelsDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
// var whereCondition strings.Builder
// var queryParams []interface{}
// if !g.IsEmpty(req.Id) {
@@ -108,25 +101,25 @@ func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...s
// return
//}
// GetByAcrossTenant 按ID获取跨租户查所有租户
func (d *modelDao) GetByAcrossTenant(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
// GetByAcrossTenant 跨租户查询
func (d *modelGatewayModelsDao) GetByAcrossTenant(ctx context.Context, req *entity.ModelGatewayModel, fields ...string) (*entity.ModelGatewayModel, error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
NoTenantId(ctx).
OmitEmpty().
Where(entity.AsynchModelCol.Id, req.Id).
Where(entity.AsynchModelCol.Creator, req.Creator).
Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel).
Where(entity.AsynchModelCol.ModelName, req.ModelName).
Where(entity.ModelGatewayModelCol.Id, req.Id).
Where(entity.ModelGatewayModelCol.Creator, req.Creator).
Where(entity.ModelGatewayModelCol.ModelName, req.ModelName).
Fields(fields).One()
if err != nil {
return
return nil, err
}
var m entity.ModelGatewayModel
err = r.Struct(&m)
return
return &m, err
}
// GetByCreatorAndPlatform 按创建者、平台获取
func (d *modelDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.AsynchModel, total int, err error) {
func (d *modelGatewayModelsDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.ModelGatewayModel, total int, err error) {
sql := `
SELECT DISTINCT ON (model_name) *
FROM asynch_models
@@ -186,7 +179,7 @@ WHERE deleted_at IS NULL
}
// GetByModelNameForTenant 后台任务使用:按 tenant_id + model_name 查询,不依赖 gfdb Hook/Trace/用户上下文
func (d *modelDao) GetByModelNameForTenant(ctx context.Context, tenantId uint64, modelName string) (m *entity.AsynchModel, err error) {
func (d *modelGatewayModelsDao) GetByModelNameForTenant(ctx context.Context, tenantId uint64, modelName string) (*entity.ModelGatewayModel, error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx,
"SELECT * FROM "+public.TableNameModel+" WHERE tenant_id=? AND model_name=? AND deleted_at IS NULL LIMIT 1",
tenantId, modelName,
@@ -197,7 +190,7 @@ func (d *modelDao) GetByModelNameForTenant(ctx context.Context, tenantId uint64,
if r.IsEmpty() {
return nil, nil
}
var list []*entity.AsynchModel
var list []*entity.ModelGatewayModel
if err := r.Structs(&list); err != nil {
return nil, err
}

View File

@@ -0,0 +1,159 @@
package dao
import (
"context"
"fmt"
"model-gateway/consts/public"
"model-gateway/model/entity"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/util/gconv"
)
// ModelGatewayTask is the package-level DAO instance for the task table.
var ModelGatewayTask = &modelGatewayTaskDao{}
// modelGatewayTaskDao is a stateless receiver for task-table queries.
type modelGatewayTaskDao struct{}
// Insert writes one task row into public.TableNameTask.
//
// req is handed to the ORM directly, like the other DAOs' Insert methods in
// this package; the former gconv.Struct copy into a second struct of the same
// type added nothing. It returns the driver's LastInsertId, or 0 and the
// error when the insert fails.
func (d *modelGatewayTaskDao) Insert(ctx context.Context, req *entity.ModelGatewayTask) (id int64, err error) {
	r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).Insert(req)
	if err != nil {
		return 0, err
	}
	return r.LastInsertId()
}
// Update writes the non-empty fields of req to the task row whose primary key
// equals req.Id and returns the number of affected rows.
//
// OmitEmpty drops empty values from both the data and the WHERE arguments, so
// zero-valued fields of req are never written, and a zero Id would remove the
// only filter. A nil req or zero Id is therefore rejected up front.
func (d *modelGatewayTaskDao) Update(ctx context.Context, req *entity.ModelGatewayTask) (rows int64, err error) {
	if req == nil || req.Id == 0 {
		return 0, fmt.Errorf("更新任务失败: 缺少主键 id")
	}
	r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
		OmitEmpty().
		Data(req).
		Where(entity.ModelGatewayTaskCol.Id, req.Id).
		Update()
	if err != nil {
		return 0, err
	}
	return r.RowsAffected()
}
// Get loads a single task filtered by req.TaskID and/or req.Id.
//
// Empty filter values are dropped by OmitEmpty.
// NOTE(review): when both TaskID and Id are empty no filter remains, so the
// query would return whichever row comes first — confirm callers always set one.
// The record is scanned into the (initially nil) result pointer.
func (d *modelGatewayTaskDao) Get(ctx context.Context, req *entity.ModelGatewayTask) (m *entity.ModelGatewayTask, err error) {
	query := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
		OmitEmpty().
		Where(entity.ModelGatewayTaskCol.TaskID, req.TaskID).
		Where(entity.ModelGatewayTaskCol.Id, req.Id)
	record, err := query.One()
	if err != nil {
		return nil, err
	}
	err = record.Struct(&m)
	return m, err
}
// List returns tasks ordered by creation time (newest first) together with
// the total number of matching rows.
//
// Filters come from req: Creator, BizName, State and TaskID are matched
// exactly and skipped when empty (OmitEmpty); ModelName is matched with
// LIKE '%name%' and only when non-empty. The previous code passed the
// "%name%" pattern to Where, which compares with '=' — and with an empty
// name compared against the literal "%%", so the list came back empty.
// Paging is applied only when both pageNum and pageSize are positive.
//
// NOTE(review): because of OmitEmpty a zero State is dropped, so tasks in
// state 0 cannot be selected through this filter — confirm that is intended.
func (d *modelGatewayTaskDao) List(ctx context.Context, pageNum, pageSize int, req *entity.ModelGatewayTask) (list []*entity.ModelGatewayTask, total int64, err error) {
	query := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
		OmitEmpty().
		Where(entity.ModelGatewayTaskCol.Creator, req.Creator).
		Where(entity.ModelGatewayTaskCol.BizName, req.BizName).
		Where(entity.ModelGatewayTaskCol.State, req.State).
		Where(entity.ModelGatewayTaskCol.TaskID, req.TaskID)
	if req.ModelName != "" {
		query = query.WhereLike(entity.ModelGatewayTaskCol.ModelName, "%"+req.ModelName+"%")
	}
	query = query.OrderDesc(entity.ModelGatewayTaskCol.CreatedAt)
	if pageNum > 0 && pageSize > 0 {
		query = query.Page(pageNum, pageSize)
	}
	r, totalInt, err := query.AllAndCount(false)
	if err != nil {
		return nil, 0, err
	}
	total = gconv.Int64(totalInt)
	err = r.Structs(&list)
	return list, total, err
}
// Delete removes the task row whose primary key equals req.Id and returns
// the number of affected rows.
//
// The original comment describes this as a soft delete; whether the ORM turns
// it into one depends on the table/gfdb configuration, which is not visible here.
func (d *modelGatewayTaskDao) Delete(ctx context.Context, req *entity.ModelGatewayTask) (rows int64, err error) {
	result, err := gfdb.DB(ctx, public.DbNameModelGateway).
		Model(ctx, public.TableNameTask).
		Where(entity.ModelGatewayTaskCol.Id, req.Id).
		Delete()
	if err != nil {
		return 0, err
	}
	return result.RowsAffected()
}
// ListByTaskIDs loads all tasks whose task ID is contained in taskIDs.
//
// An empty taskIDs slice short-circuits to (nil, nil) without querying.
func (d *modelGatewayTaskDao) ListByTaskIDs(ctx context.Context, taskIDs []string) (list []*entity.ModelGatewayTask, err error) {
	if len(taskIDs) == 0 {
		return nil, nil
	}
	rows, err := gfdb.DB(ctx, public.DbNameModelGateway).
		Model(ctx, public.TableNameTask).
		WhereIn(entity.ModelGatewayTaskCol.TaskID, taskIDs).
		All()
	if err != nil {
		return nil, err
	}
	if err = rows.Structs(&list); err != nil {
		return list, err
	}
	return list, nil
}
// MarkDownloadedByID moves the task with the given primary key from the
// success state to the downloaded state.
//
// The state filter makes the update a no-op for tasks that are not currently
// in the success state. The states are named with the public.TaskStatus*
// constants (same values as the former literals 2 and 4).
func (d *modelGatewayTaskDao) MarkDownloadedByID(ctx context.Context, id int64) error {
	_, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
		Where(entity.ModelGatewayTaskCol.Id, id).
		Where(entity.ModelGatewayTaskCol.State, public.TaskStatusSuccess).
		Data(map[string]any{entity.ModelGatewayTaskCol.State: public.TaskStatusDownloaded}).
		Update()
	return err
}
// GetPendingAsyncTasks returns up to limit asynchronous tasks that are still
// in progress.
//
// Despite the "Pending" in its name, the query selects rows in the running
// state (public.TaskStatusRunning, the former literal 1).
func (d *modelGatewayTaskDao) GetPendingAsyncTasks(ctx context.Context, limit int) ([]*entity.ModelGatewayTask, error) {
	var tasks []*entity.ModelGatewayTask
	err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
		Where(entity.ModelGatewayTaskCol.State, public.TaskStatusRunning).
		Limit(limit).
		Scan(&tasks)
	return tasks, err
}
// ======================== 事务抢占 ========================
// ClaimByID claims the task with the given primary key inside a transaction
// and returns it.
//
// The row is selected with a row lock (LockUpdate) and only when it is still
// in the pending state; it is then switched to running in the same
// transaction. When no such row exists — missing, or already claimed — an
// error is returned. The returned task reflects the claimed state: its State
// is set to public.TaskStatusRunning (previously the stale pre-claim value
// was returned).
//
// NOTE(review): the queries run through tx.Model rather than the gfdb
// Model(ctx, ...) wrapper used elsewhere — confirm this is intentional.
func (d *modelGatewayTaskDao) ClaimByID(ctx context.Context, id int64) (*entity.ModelGatewayTask, error) {
	var task entity.ModelGatewayTask
	err := gfdb.DB(ctx, public.DbNameModelGateway).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
		// Lock the row, but only if nobody has claimed it yet.
		r, err := tx.Model(public.TableNameTask).
			Where(entity.ModelGatewayTaskCol.Id, id).
			Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending).
			Limit(1).
			LockUpdate().
			One()
		if err != nil {
			return err
		}
		if r.IsEmpty() {
			return fmt.Errorf("任务已被抢占或不存在: id=%d", id)
		}
		if err := r.Struct(&task); err != nil {
			return err
		}
		// OmitEmpty keeps the update limited to the non-zero State field.
		_, err = tx.Model(public.TableNameTask).
			Data(&entity.ModelGatewayTask{State: public.TaskStatusRunning}).
			Where(entity.ModelGatewayTaskCol.Id, id).
			OmitEmpty().
			Update()
		return err
	})
	if err != nil {
		return nil, err
	}
	// The struct was read before the update; mirror the committed state.
	task.State = public.TaskStatusRunning
	return &task, nil
}

View File

@@ -1,30 +0,0 @@
package dao
import (
"context"
"model-gateway/consts/public"
"model-gateway/model/entity"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/util/gconv"
)
// opLogDao is a stateless receiver for operation-log queries.
type opLogDao struct{}
// OpLog is the package-level DAO instance for the operation-log table.
var OpLog = &opLogDao{}
// Insert copies req into a fresh entity via gconv.Struct and writes that copy
// into public.TableNameOpLog.
//
// It returns the driver's LastInsertId, or 0 and the error when either the
// copy or the insert fails.
func (d *opLogDao) Insert(ctx context.Context, req *entity.LogsModelOp) (id int64, err error) {
	row := new(entity.LogsModelOp)
	if err = gconv.Struct(req, &row); err != nil {
		return 0, err
	}
	result, err := gfdb.DB(ctx, public.DbNameModelGateway).
		Model(ctx, public.TableNameOpLog).
		Insert(row)
	if err != nil {
		return 0, err
	}
	return result.LastInsertId()
}

View File

@@ -1,60 +0,0 @@
package dao
import (
"context"
"fmt"
"time"
"model-gateway/consts/public"
"model-gateway/model/entity"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/os/gtime"
)
// statDao is a stateless receiver for statistics queries.
type statDao struct{}
// Stat is the package-level DAO instance for the daily statistics table.
var Stat = &statDao{}
// IncRequestCount adds 1 to the request counter identified by
// (day, tenant, creator, model) using a single INSERT ... ON CONFLICT
// statement, creating the row with a count of 1 when it does not exist.
//
// day is bound as a "Y-m-d" formatted string; the table name is the only
// value interpolated into the statement.
func (d *statDao) IncRequestCount(ctx context.Context, day time.Time, tenantId int64, creator, modelName string) error {
	upsert := fmt.Sprintf(`
INSERT INTO %s(day, tenant_id, creator, model_name, request_count, created_at, updated_at)
VALUES(?, ?, ?, ?, 1, NOW(), NOW())
ON CONFLICT (day, tenant_id, creator, model_name)
DO UPDATE SET request_count = %s.request_count + 1, updated_at = NOW()`,
		public.TableNameStat, public.TableNameStat,
	)
	dayKey := gtime.New(day).Format("Y-m-d")
	_, execErr := gfdb.DB(ctx, public.DbNameModelGateway).Exec(ctx, upsert, dayKey, tenantId, creator, modelName)
	return execErr
}
// List returns statistics rows ordered by day and then request count (both
// descending), together with the total number of matching rows.
//
// Every filter is optional: startDay / endDay bound the day column
// inclusively, tenantId (when non-nil) is matched exactly, creator and
// modelName are matched with LIKE '%value%'. Paging is applied only when both
// pageNum and pageSize are positive.
//
// The database is now selected with public.DbNameModelGateway, matching
// IncRequestCount in this file, instead of relying on gfdb's default.
func (d *statDao) List(ctx context.Context, pageNum, pageSize int, startDay, endDay string, tenantId *int64, creator, modelName string) (list []*entity.LogsModelStat, total int64, err error) {
	m := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameStat).Where("1=1")
	if startDay != "" {
		m = m.Where("day >= ?", startDay)
	}
	if endDay != "" {
		m = m.Where("day <= ?", endDay)
	}
	if tenantId != nil {
		m = m.Where("tenant_id = ?", *tenantId)
	}
	if creator != "" {
		m = m.WhereLike("creator", "%"+creator+"%")
	}
	if modelName != "" {
		m = m.WhereLike("model_name", "%"+modelName+"%")
	}
	m = m.OrderDesc("day").OrderDesc("request_count")
	if pageNum > 0 && pageSize > 0 {
		m = m.Page(pageNum, pageSize)
	}
	r, totalInt, err := m.AllAndCount(false)
	if err != nil {
		return nil, 0, err
	}
	total = int64(totalInt)
	err = r.Structs(&list)
	return list, total, err
}

View File

@@ -1,124 +0,0 @@
package dao
import (
"context"
"model-gateway/consts/public"
"model-gateway/model/entity"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/util/gconv"
)
// Task is the package-level DAO instance for the task table.
var Task = &taskDao{}
// taskDao is a stateless receiver for task-table queries.
type taskDao struct{}
// Insert copies req into a fresh entity via gconv.Struct and writes that copy
// into public.TableNameTask.
//
// The database is now selected with public.DbNameModelGateway like every
// other method of taskDao in this file, instead of relying on gfdb's default.
// It returns the driver's LastInsertId, or 0 and the error on failure.
func (d *taskDao) Insert(ctx context.Context, req *entity.AsynchTask) (id int64, err error) {
	m := new(entity.AsynchTask)
	if err = gconv.Struct(req, &m); err != nil {
		return 0, err
	}
	r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
		Insert(m)
	if err != nil {
		return 0, err
	}
	return r.LastInsertId()
}
// Update writes the non-empty fields of req to the task row whose primary key
// equals req.Id and returns the number of affected rows.
//
// req is already a pointer, so it is passed to Data as-is; the former
// Data(&req) handed the ORM a pointer to a pointer. OmitEmpty means
// zero-valued fields of req are not written.
func (d *taskDao) Update(ctx context.Context, req *entity.AsynchTask) (rows int64, err error) {
	r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
		OmitEmpty().
		Data(req).
		Where(entity.AsynchTaskCol.Id, req.Id).
		Update()
	if err != nil {
		return 0, err
	}
	return r.RowsAffected()
}
// Get loads a single task filtered by req.TaskID, optionally restricted to
// the given fields.
//
// An empty TaskID is dropped by OmitEmpty, which leaves the query without a
// task-ID filter. The record is scanned into the (initially nil) result pointer.
func (d *taskDao) Get(ctx context.Context, req *entity.AsynchTask, fields ...string) (m *entity.AsynchTask, err error) {
	query := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
		OmitEmpty().
		Where(entity.AsynchTaskCol.TaskID, req.TaskID).
		Fields(fields)
	record, err := query.One()
	if err != nil {
		return nil, err
	}
	err = record.Struct(&m)
	return m, err
}
// ListByTaskIDs loads all tasks whose task ID is contained in taskIDs.
//
// According to the original note, gfdb's tenant hook applies, so only rows of
// the current tenant are returned. An empty taskIDs slice short-circuits to
// (nil, nil) without querying.
func (d *taskDao) ListByTaskIDs(ctx context.Context, taskIDs []string) (m []*entity.AsynchTask, err error) {
	if len(taskIDs) == 0 {
		return nil, nil
	}
	rows, err := gfdb.DB(ctx, public.DbNameModelGateway).
		Model(ctx, public.TableNameTask).
		OmitEmpty().
		WhereIn(entity.AsynchTaskCol.TaskID, taskIDs).
		All()
	if err != nil {
		return nil, err
	}
	if err = rows.Structs(&m); err != nil {
		return m, err
	}
	return m, nil
}
// MarkDownloadedByID marks a successful task (state 2) as downloaded
// (state 4), stores the given expiry time and blanks the updater column.
//
// The state filter makes the update a no-op for tasks not currently in state 2.
func (d *taskDao) MarkDownloadedByID(ctx context.Context, id int64, expireAt *gtime.Time) error {
	changes := gdb.Map{}
	changes[entity.AsynchTaskCol.State] = 4
	changes[entity.AsynchTaskCol.ExpireAt] = expireAt
	changes[entity.AsynchTaskCol.Updater] = ""

	_, updateErr := gfdb.DB(ctx, public.DbNameModelGateway).
		Model(ctx, public.TableNameTask).
		Where(entity.AsynchTaskCol.Id, id).
		Where(entity.AsynchTaskCol.State, 2).
		Data(changes).
		Update()
	return updateErr
}
// List returns non-deleted tasks ordered by creation time (newest first)
// together with the total number of matching rows.
//
// modelNameLike and taskIDLike are matched with LIKE '%value%' when
// non-empty; state is matched exactly when non-nil. Paging is applied only
// when both pageNum and pageSize are positive. According to the original
// note, gfdb's tenant hook applies to this query.
func (d *taskDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike, taskIDLike string, state *int) (list []*entity.AsynchTask, total int64, err error) {
	query := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).Where("deleted_at IS NULL")
	if modelNameLike != "" {
		query = query.WhereLike(entity.AsynchTaskCol.ModelName, "%"+modelNameLike+"%")
	}
	if taskIDLike != "" {
		query = query.WhereLike(entity.AsynchTaskCol.TaskID, "%"+taskIDLike+"%")
	}
	if state != nil {
		query = query.Where(entity.AsynchTaskCol.State, *state)
	}
	query = query.OrderDesc(entity.AsynchTaskCol.CreatedAt)
	if pageNum > 0 && pageSize > 0 {
		query = query.Page(pageNum, pageSize)
	}

	rows, count, err := query.AllAndCount(false)
	if err != nil {
		return nil, 0, err
	}
	total = gconv.Int64(count)
	err = rows.Structs(&list)
	return list, total, err
}
// GetPendingAsyncTasks returns up to limit non-deleted asynchronous tasks
// that are still in progress (state = 1).
func (d *taskDao) GetPendingAsyncTasks(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
	query := gfdb.DB(ctx, public.DbNameModelGateway).
		Model(ctx, public.TableNameTask).
		Where("state", 1).
		Where("deleted_at IS NULL").
		Limit(limit)

	var inProgress []*entity.AsynchTask
	scanErr := query.Scan(&inProgress)
	return inProgress, scanErr
}

View File

@@ -1,214 +0,0 @@
package dao
import (
"context"
"fmt"
"time"
"model-gateway/consts/public"
"model-gateway/model/entity"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/os/gtime"
)
// ======================== 查询辅助 ========================
// taskColumns is the shared column list selected by the task-claiming queries below.
const taskColumns = `id, tenant_id, creator, model_name, task_id, biz_name, callback_url, model_key, retry_count, input_ref, request_payload, phase, tmp_file`
// ======================== 事务抢占 ========================
// claimTasks claims at most one pending task (state = 0) inside a transaction
// and marks it as running (state = 1).
//
// where is appended verbatim to the SELECT's WHERE clause and args are bound
// to its placeholders. The row is picked with FOR UPDATE SKIP LOCKED. It
// returns a one-element slice with the task as read BEFORE the state update,
// or a nil slice when nothing matched.
func claimTasks(ctx context.Context, where string, args ...any) ([]*entity.AsynchTask, error) {
	var claimed []*entity.AsynchTask
	txErr := gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
		selectSQL := fmt.Sprintf(`SELECT %s FROM %s WHERE deleted_at IS NULL AND state = 0 %s LIMIT 1 FOR UPDATE SKIP LOCKED`, taskColumns, public.TableNameTask, where)
		row, err := tx.GetOne(selectSQL, args...)
		if err != nil {
			return err
		}
		if row.IsEmpty() {
			return nil
		}
		picked := new(entity.AsynchTask)
		if err := row.Struct(picked); err != nil {
			return err
		}
		stamp := time.Now()
		updateSQL := fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask)
		if _, err := tx.Exec(updateSQL, stamp, stamp, picked.Id); err != nil {
			return err
		}
		claimed = []*entity.AsynchTask{picked}
		return nil
	})
	return claimed, txErr
}
// ClaimPendingGlobal claims up to batchSize pending tasks (state = 0), oldest
// enqueue_at first, and marks each as running (state = 1) in one transaction.
//
// A non-positive batchSize is treated as 1. Rows are selected with
// FOR UPDATE SKIP LOCKED. The returned tasks hold the values read BEFORE the
// state update.
func (d *taskDao) ClaimPendingGlobal(ctx context.Context, batchSize int) ([]*entity.AsynchTask, error) {
	size := batchSize
	if size <= 0 {
		size = 1
	}
	var claimed []*entity.AsynchTask
	txErr := gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
		selectSQL := fmt.Sprintf(`SELECT %s FROM %s WHERE deleted_at IS NULL AND state = 0 ORDER BY enqueue_at ASC LIMIT %d FOR UPDATE SKIP LOCKED`, taskColumns, public.TableNameTask, size)
		rows, err := tx.GetAll(selectSQL)
		if err != nil {
			return err
		}
		if rows.IsEmpty() {
			return nil
		}
		if err := rows.Structs(&claimed); err != nil {
			return err
		}
		stamp := time.Now()
		updateSQL := fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask)
		for _, task := range claimed {
			if _, err = tx.Exec(updateSQL, stamp, stamp, task.Id); err != nil {
				return err
			}
		}
		return nil
	})
	return claimed, txErr
}
// ClaimPendingByTaskIDGlobal claims the pending task with the given task ID.
//
// It returns (nil, nil) when taskID is empty or no pending task matched, and
// (nil, err) when claiming failed.
func (d *taskDao) ClaimPendingByTaskIDGlobal(ctx context.Context, taskID string) (*entity.AsynchTask, error) {
	if taskID == "" {
		return nil, nil
	}
	claimed, err := claimTasks(ctx, "AND task_id = ?", taskID)
	switch {
	case err != nil:
		return nil, err
	case len(claimed) == 0:
		return nil, nil
	}
	return claimed[0], nil
}
// ======================== 更新辅助 ========================
// execSQL runs a raw statement with bound arguments on gfdb's default
// database and discards the result, returning only the error.
func execSQL(ctx context.Context, sql string, args ...any) error {
	if _, err := gfdb.DB(ctx).Exec(ctx, sql, args...); err != nil {
		return err
	}
	return nil
}
// updateTask writes the non-empty fields of data to the task row with the
// given primary key.
//
// OmitEmpty is enabled, so zero-valued fields of data are not written.
func updateTask(ctx context.Context, id int64, data entity.AsynchTask) error {
	query := gfdb.DB(ctx).
		Model(ctx, public.TableNameTask).
		OmitEmpty().
		Where(entity.AsynchTaskCol.Id, id).
		Data(data)
	_, updateErr := query.Update()
	return updateErr
}
// UpdateSuccessGlobal marks a task as succeeded (state 2), stores its result
// fields and finish time, and resets error_msg, phase and tmp_file.
//
// updateTask runs with OmitEmpty, so the zero values the original literal
// carried (ErrorMsg "", Phase 0, TmpFile "") were silently dropped and the
// columns kept their old contents. They are now cleared by an explicit
// follow-up statement.
//
// NOTE(review): the two statements are not wrapped in a transaction; if the
// second fails the task is already marked successful and the error is returned.
func (d *taskDao) UpdateSuccessGlobal(ctx context.Context, t *entity.AsynchTask) error {
	if err := updateTask(ctx, t.Id, entity.AsynchTask{
		State:           2,
		OssFile:         t.OssFile,
		FileType:        t.FileType,
		TextResult:      t.TextResult,
		FileSize:        t.FileSize,
		FinishedAt:      gtime.Now(),
		ExpendTokens:    t.ExpendTokens,
		DurationSeconds: t.DurationSeconds,
	}); err != nil {
		return err
	}
	// Reset the columns OmitEmpty cannot write.
	return execSQL(ctx, fmt.Sprintf(`UPDATE %s SET error_msg='', phase=0, tmp_file='', updated_at=NOW() WHERE id=?`, public.TableNameTask), t.Id)
}
// UpdateFailedGlobal marks a task as failed (state 3) after a model-call
// failure, stores the error message, text result, duration and finish time,
// and resets phase and tmp_file.
//
// updateTask runs with OmitEmpty, so the zero values the original literal
// carried (Phase 0, TmpFile "") were silently dropped and the columns kept
// their old contents. They are now cleared by an explicit follow-up statement.
//
// NOTE(review): the two statements are not wrapped in a transaction; if the
// second fails the task is already marked failed and the error is returned.
func (d *taskDao) UpdateFailedGlobal(ctx context.Context, t *entity.AsynchTask) error {
	if err := updateTask(ctx, t.Id, entity.AsynchTask{
		State:           3,
		ErrorMsg:        t.ErrorMsg,
		FinishedAt:      gtime.Now(),
		TextResult:      t.TextResult,
		DurationSeconds: t.DurationSeconds,
	}); err != nil {
		return err
	}
	// Reset the columns OmitEmpty cannot write.
	return execSQL(ctx, fmt.Sprintf(`UPDATE %s SET phase=0, tmp_file='', updated_at=NOW() WHERE id=?`, public.TableNameTask), t.Id)
}
// UpdateFailedKeepTmpGlobal marks a task as failed (state 3) after an OSS
// upload failure while leaving phase at 1 and tmp_file untouched.
func (d *taskDao) UpdateFailedKeepTmpGlobal(ctx context.Context, id int64, errorMsg string) error {
	stmt := fmt.Sprintf(`UPDATE %s SET state=3, error_msg=?, finished_at=?, phase=1, updated_at=? WHERE id=?`, public.TableNameTask)
	finishedAt := gtime.Now()
	updatedAt := gtime.Now()
	return execSQL(ctx, stmt, errorMsg, finishedAt, updatedAt, id)
}
// UpdateTmpAfterModelGlobal records the temporary file written for a task
// and sets its phase to 1.
func (d *taskDao) UpdateTmpAfterModelGlobal(ctx context.Context, id int64, tmpFile string) error {
	stmt := fmt.Sprintf(`UPDATE %s SET phase=1, tmp_file=?, updated_at=NOW() WHERE id=?`, public.TableNameTask)
	return execSQL(ctx, stmt, tmpFile, id)
}
// RollbackToPendingGlobal puts a running task (state 1) back into the
// pending state (state 0) and refreshes its enqueue time. Tasks not in
// state 1 are left untouched.
func (d *taskDao) RollbackToPendingGlobal(ctx context.Context, id int64) error {
	stmt := fmt.Sprintf(`UPDATE %s SET state=0, enqueue_at=NOW(), updated_at=NOW() WHERE id=? AND state=1`, public.TableNameTask)
	return execSQL(ctx, stmt, id)
}
// IncRetryCountGlobal increments the retry counter of the given task by one.
func (d *taskDao) IncRetryCountGlobal(ctx context.Context, id int64) error {
	stmt := fmt.Sprintf(`UPDATE %s SET retry_count=retry_count+1, updated_at=NOW() WHERE id=?`, public.TableNameTask)
	return execSQL(ctx, stmt, id)
}
// RequeueForRetryGlobal re-enqueues a failed, non-deleted task: state goes
// from 3 back to 0, the retry counter is incremented and enqueue_at is set to
// enqueueAt. Tasks not in state 3 are left untouched.
func (d *taskDao) RequeueForRetryGlobal(ctx context.Context, id int64, enqueueAt time.Time) error {
	stmt := fmt.Sprintf(`UPDATE %s SET state=0, retry_count=retry_count+1, enqueue_at=?, updated_at=NOW() WHERE id=? AND state=3 AND deleted_at IS NULL`, public.TableNameTask)
	return execSQL(ctx, stmt, enqueueAt, id)
}
// ======================== 列表查询 ========================
// ListExpiredDownloadedGlobal returns non-deleted tasks in state 4 whose
// expire_at is set and lies before the current time. A non-positive limit
// falls back to 200.
func (d *taskDao) ListExpiredDownloadedGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
	stmt := fmt.Sprintf(`SELECT * FROM %s WHERE deleted_at IS NULL AND state=4 AND expire_at IS NOT NULL AND expire_at < ? LIMIT ?`, public.TableNameTask)
	cutoff := gtime.Now()
	return queryTasks(ctx, stmt, cutoff, clampLimit(limit, 200))
}
// ListFailedRetryableGlobal returns non-deleted failed tasks (state 3) whose
// retry_count is still below the retry_times of their model (joined on tenant
// and model name), least recently updated first. The model's
// retry_queue_max_seconds is selected alongside the task columns. A
// non-positive limit falls back to 200.
func (d *taskDao) ListFailedRetryableGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
	stmt := fmt.Sprintf(`SELECT t.*, m.retry_queue_max_seconds FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state=3 AND t.retry_count < m.retry_times ORDER BY t.updated_at ASC LIMIT ?`, public.TableNameTask, public.TableNameModel)
	return queryTasks(ctx, stmt, clampLimit(limit, 200))
}
// ListFailedExhaustedGlobal returns non-deleted tasks in state 3 whose
// retry_count has reached or exceeded the retry_times of the matching model
// row (joined on tenant_id and model_name), oldest updated_at first.
// A non-positive limit falls back to 200 rows.
func (d *taskDao) ListFailedExhaustedGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
	query := fmt.Sprintf(
		`SELECT t.* FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state=3 AND t.retry_count >= m.retry_times ORDER BY t.updated_at ASC LIMIT ?`,
		public.TableNameTask,
		public.TableNameModel,
	)
	return queryTasks(ctx, query, clampLimit(limit, 200))
}
// ListTimeoutTasksGlobal returns non-deleted tasks in state 0 or 1 that were
// created more than expected_seconds ago, where expected_seconds comes from
// the matching model row (joined on tenant_id and model_name) and must be
// positive. A non-positive limit falls back to 200 rows.
func (d *taskDao) ListTimeoutTasksGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
	query := fmt.Sprintf(
		`SELECT t.* FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state IN (0,1) AND m.expected_seconds > 0 AND t.created_at < (NOW() - (m.expected_seconds || ' seconds')::interval) LIMIT ?`,
		public.TableNameTask,
		public.TableNameModel,
	)
	return queryTasks(ctx, query, clampLimit(limit, 200))
}
// HardDeleteByIDGlobal physically removes the task row with the given id
// (a DELETE statement, not a soft delete via deleted_at).
func (d *taskDao) HardDeleteByIDGlobal(ctx context.Context, id int64) error {
	query := fmt.Sprintf(`DELETE FROM %s WHERE id=?`, public.TableNameTask)
	return execSQL(ctx, query, id)
}
// ======================== 内部辅助 ========================
// queryTasks runs the given SQL with its arguments and converts the result
// rows into a slice of task entities. A query error yields a nil slice; a
// conversion error is returned together with whatever the conversion left in
// the slice.
func queryTasks(ctx context.Context, sql string, args ...any) ([]*entity.AsynchTask, error) {
	rows, err := gfdb.DB(ctx).GetAll(ctx, sql, args...)
	if err != nil {
		return nil, err
	}
	var tasks []*entity.AsynchTask
	if err = rows.Structs(&tasks); err != nil {
		return tasks, err
	}
	return tasks, nil
}
// clampLimit returns limit when it is positive and defaultVal otherwise.
// Despite the name, no upper bound is applied to limit.
func clampLimit(limit, defaultVal int) int {
	if limit > 0 {
		return limit
	}
	return defaultVal
}
// UpdateColumns updates the task row with the given id from the supplied
// struct. The model is built with OmitEmpty, so
// NOTE(review): fields left at their zero value are presumably skipped rather
// than written — confirm against the ORM's OmitEmpty semantics before relying
// on this to reset a column to 0 or "".
func (d *taskDao) UpdateColumns(ctx context.Context, id int64, data entity.AsynchTask) error {
	base := gfdb.DB(ctx).Model(ctx, public.TableNameTask).OmitEmpty()
	stmt := base.Where(entity.AsynchTaskCol.Id, id).Data(data)
	if _, err := stmt.Update(); err != nil {
		return err
	}
	return nil
}

29
main.go
View File

@@ -3,7 +3,6 @@ package main
import (
"context"
"model-gateway/model/dto"
"model-gateway/service/job"
"model-gateway/service/task"
"os"
"os/signal"
@@ -27,9 +26,9 @@ func main() {
// 注册路由
http.RouteRegister([]interface{}{
controller.Model,
controller.Task,
controller.Stat,
controller.ModelGatewayModels,
controller.ModelGatewayTask,
controller.ModelGatewayLogsStat,
})
// 本地调试:可选自动触发 worker/cleaner由配置文件控制
@@ -47,26 +46,6 @@ func main() {
}
func startAutoRunner(ctx context.Context) {
// cleaner
if g.Cfg().MustGet(ctx, "asynch.cleaner.enabled").Bool() {
interval := g.Cfg().MustGet(ctx, "asynch.cleaner.intervalSeconds").Int()
if interval <= 0 {
interval = 30
}
ticker := time.NewTicker(time.Duration(interval) * time.Second)
go func() {
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
_, _ = job.Cleaner.RunOnce(ctx)
}
}
}()
}
// queryPending
if g.Cfg().MustGet(ctx, "asynch.queryPending.enabled").Bool() {
interval := g.Cfg().MustGet(ctx, "asynch.queryPending.intervalSeconds", 10).Int()
@@ -79,7 +58,7 @@ func startAutoRunner(ctx context.Context) {
case <-ctx.Done():
return
case <-ticker.C:
if _, err := task.Task.QueryPendingTasks(ctx, &dto.QueryPendingTasksReq{Limit: limit}); err != nil {
if _, err := task.ModelGatewayTask.QueryPendingTasks(ctx, &dto.QueryPendingTasksReq{Limit: limit}); err != nil {
g.Log().Warningf(ctx, "[auto-queryPending] run once failed: %v", err)
}
}

View File

@@ -0,0 +1,20 @@
package dto
import "github.com/gogf/gf/v2/frame/g"
// ListModelStatReq is the request for the per-day model request statistics
// list (GET /listModelStat). Every filter field is optional.
type ListModelStatReq struct {
g.Meta `path:"/listModelStat" method:"get" tags:"统计" summary:"模型请求统计列表" dc:"按天统计模型请求次数,支持分页与条件筛选"`
// Pagination; the tags document defaults of page 1 and 10 rows per page.
PageNum int `p:"pageNum" json:"pageNum" dc:"页码默认1"`
PageSize int `p:"pageSize" json:"pageSize" dc:"每页条数默认10"`
// Date range filter, documented as YYYY-MM-DD strings.
StartDay string `p:"startDay" json:"startDay" dc:"开始日期YYYY-MM-DD可选"`
EndDay string `p:"endDay" json:"endDay" dc:"结束日期YYYY-MM-DD可选"`
// TenantID is a pointer; presumably nil means "no tenant filter" — confirm against the caller.
TenantID *int64 `p:"tenantId" json:"tenantId" dc:"租户ID可选"`
// Creator and ModelName are documented as fuzzy-match filters.
Creator string `p:"creator" json:"creator" dc:"创建人(可选,模糊匹配)"`
ModelName string `p:"modelName" json:"modelName" dc:"模型名称(可选,模糊匹配)"`
}
// ListModelStatRes is the response of the model statistics list: the page of
// rows plus the total row count.
type ListModelStatRes struct {
List any `json:"list" dc:"列表数据"`
Total int64 `json:"total" dc:"总数"`
}

View File

@@ -103,7 +103,7 @@ type GetModelReq struct {
}
type GetModelRes struct {
Model *entity.AsynchModel `json:"model" dc:"模型配置详情"`
Model *entity.ModelGatewayModel `json:"model" dc:"模型配置详情"`
}
// ListModelReq 配置列表

View File

@@ -1,6 +1,8 @@
package dto
import "github.com/gogf/gf/v2/frame/g"
import (
"github.com/gogf/gf/v2/frame/g"
)
// CreateTaskReq 创建异步任务
type CreateTaskReq struct {
@@ -8,7 +10,6 @@ type CreateTaskReq struct {
ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"模型名称"`
BizName string `p:"bizName" json:"bizName" dc:"业务名称(调用方模块/系统,用于统计)"`
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址(可选,用于后续业务通知)"`
InputRef string `p:"inputRef" json:"inputRef" dc:"输入引用如OSS/文件引用等)"`
RequestPayload map[string]any `p:"requestPayload" json:"requestPayload" dc:"请求负载(透传给模型服务)"`
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
BuildType int64 `json:"buildType" dc:"构建类型1-提示词构建 2-节点构建"`
@@ -67,24 +68,26 @@ type GetTaskBatchReq struct {
TaskIDs []string `p:"taskIds" json:"taskIds" v:"required#taskIds不能为空" dc:"任务ID列表"`
}
type GetTaskBatchItem struct {
TaskID string `json:"taskId" dc:"任务ID"`
State int `json:"state" dc:"任务状态"`
OssFile string `json:"ossFile" dc:"结果文件OSS地址"`
}
type GetTaskBatchRes struct {
List []GetTaskBatchItem `json:"list" dc:"任务列表"`
}
type GetTaskBatchItem struct {
TaskID string `json:"taskId" dc:"任务ID"`
State int `json:"state" dc:"任务状态"`
OssFile string `json:"ossFile" dc:"结果文件OSS地址"`
TextResult map[string]any `json:"textResult" dc:"文本结果"`
}
// ListTaskReq 任务列表分页查询
type ListTaskReq struct {
g.Meta `path:"/listTask" method:"get" tags:"任务管理" summary:"任务列表" dc:"分页查询任务列表,支持按状态/模型名称/task_id过滤"`
PageNum int `p:"pageNum" json:"pageNum" dc:"页码默认1"`
PageSize int `p:"pageSize" json:"pageSize" dc:"每页条数默认10"`
ModelName string `p:"modelName" json:"modelName" dc:"模型名称(模糊匹配)"`
BizName string `p:"bizName" json:"bizName" dc:"业务名称"`
TaskID string `p:"taskId" json:"taskId" dc:"任务ID模糊匹配"`
State *int `p:"state" json:"state" dc:"任务状态0/1/2/3/4可选"`
State int `p:"state" json:"state" dc:"任务状态0/1/2/3/4可选"`
}
type ListTaskRes struct {
@@ -102,12 +105,3 @@ type RunWorkReq struct {
type RunWorkRes struct {
Claimed int `json:"claimed" dc:"本次抢占并处理的任务数"`
}
// CleanWorkReq 手动触发 cleaner 执行一次(由上层定时任务调用)
type CleanWorkReq struct {
g.Meta `path:"/cleanWork" method:"post" tags:"任务管理" summary:"执行一次Cleaner" dc:"手动触发一次清理/重试(用于由上层定时任务控制)"`
}
type CleanWorkRes struct {
Ok bool `json:"ok" dc:"是否执行成功"`
}

View File

@@ -1,20 +0,0 @@
package dto
import "github.com/gogf/gf/v2/frame/g"
// ListModelStatReq is the request for the per-day model request statistics
// list (GET /listModelStat). Every filter field is optional.
type ListModelStatReq struct {
g.Meta `path:"/listModelStat" method:"get" tags:"统计" summary:"模型请求统计列表" dc:"按天统计模型请求次数,支持分页与条件筛选"`
// Pagination; the tags document defaults of page 1 and 10 rows per page.
PageNum int `p:"pageNum" json:"pageNum" dc:"页码默认1"`
PageSize int `p:"pageSize" json:"pageSize" dc:"每页条数默认10"`
// Date range filter, documented as YYYY-MM-DD strings.
StartDay string `p:"startDay" json:"startDay" dc:"开始日期YYYY-MM-DD可选"`
EndDay string `p:"endDay" json:"endDay" dc:"结束日期YYYY-MM-DD可选"`
// TenantID is a pointer; presumably nil means "no tenant filter" — confirm against the caller.
TenantID *int64 `p:"tenantId" json:"tenantId" dc:"租户ID可选"`
// Creator and ModelName are documented as fuzzy-match filters.
Creator string `p:"creator" json:"creator" dc:"创建人(可选,模糊匹配)"`
ModelName string `p:"modelName" json:"modelName" dc:"模型名称(可选,模糊匹配)"`
}
// ListModelStatRes is the response of the model statistics list: the page of
// rows plus the total row count.
type ListModelStatRes struct {
List any `json:"list" dc:"列表数据"`
Total int64 `json:"total" dc:"总数"`
}

View File

@@ -1,89 +0,0 @@
package entity
import (
"gitea.redpowerfuture.com/red-future/common/beans"
"github.com/gogf/gf/v2/os/gtime"
)
// asynchTaskCol lists the database column names of the asynchronous task
// table, one string field per column, on top of the shared base columns.
type asynchTaskCol struct {
beans.SQLBaseCol
ModelName string
TaskID string
BizName string
CallbackURL string
ModelKey string
State string
OssFile string
FileType string
FileSize string
ErrorMsg string
StartedAt string
FinishedAt string
DurationSeconds string
ExpireAt string
RetryCount string
EnqueueAt string
Phase string
TmpFile string
InputRef string
RequestPayload string
TextResult string
EpicycleId string
ExpendTokens string
}
// AsynchTaskCol holds the concrete column names of the asynchronous task
// table so queries can reference columns without repeating string literals.
var AsynchTaskCol = asynchTaskCol{
SQLBaseCol: beans.DefSQLBaseCol,
ModelName: "model_name",
TaskID: "task_id",
BizName: "biz_name",
CallbackURL: "callback_url",
ModelKey: "model_key",
State: "state",
OssFile: "oss_file",
FileType: "file_type",
FileSize: "file_size",
ErrorMsg: "error_msg",
StartedAt: "started_at",
FinishedAt: "finished_at",
DurationSeconds: "duration_seconds",
ExpireAt: "expire_at",
RetryCount: "retry_count",
EnqueueAt: "enqueue_at",
Phase: "phase",
TmpFile: "tmp_file",
InputRef: "input_ref",
RequestPayload: "request_payload",
TextResult: "text_result",
EpicycleId: "epicycle_id",
ExpendTokens: "expend_tokens",
}
// AsynchTask is the ORM entity for one asynchronous task row.
type AsynchTask struct {
beans.SQLBaseDO `orm:",inline"`
ModelName string `orm:"model_name" json:"modelName"`
TaskID string `orm:"task_id" json:"taskId"`
BizName string `orm:"biz_name" json:"bizName"`
CallbackURL string `orm:"callback_url" json:"callbackUrl"`
ModelKey string `orm:"model_key" json:"modelKey"`
State int `orm:"state" json:"state"` // 0 queued / 1 running / 2 succeeded / 3 failed / 4 downloaded
OssFile string `orm:"oss_file" json:"ossFile"`
FileType string `orm:"file_type" json:"fileType"`
FileSize int64 `orm:"file_size" json:"fileSize"`
ErrorMsg string `orm:"error_msg" json:"errorMsg"`
StartedAt *gtime.Time `orm:"started_at" json:"startedAt"`
FinishedAt *gtime.Time `orm:"finished_at" json:"finishedAt"`
DurationSeconds int64 `orm:"duration_seconds" json:"durationSeconds"`
ExpireAt *gtime.Time `orm:"expire_at" json:"expireAt"` // expiry time once the task is in the downloaded state (state=4)
RetryCount int `orm:"retry_count" json:"retryCount"`
EnqueueAt *gtime.Time `orm:"enqueue_at" json:"enqueueAt"`
Phase int `orm:"phase" json:"phase"` // 0 model phase / 1 OSS phase
TmpFile string `orm:"tmp_file" json:"tmpFile"` // path of the temporary result file
InputRef string `orm:"input_ref" json:"inputRef"`
RequestPayload map[string]any `orm:"request_payload" json:"requestPayload"`
// Note: the JSON key is "text", not "textResult".
TextResult map[string]any `orm:"text_result" json:"text"`
EpicycleId int64 `orm:"epicycle_id" json:"epicycleId"` // round ID, identifies tasks that belong to the same round
ExpendTokens int64 `orm:"expend_tokens" json:"expendTokens"` // number of tokens consumed
// RetryQueueMaxSeconds is excluded from JSON output. It maps to
// retry_queue_max_seconds, which appears in asynchTaskCol neither as a
// field nor as a column name; presumably it is filled only by queries
// that select that column explicitly — confirm.
RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"-"`
}

View File

@@ -1,57 +0,0 @@
package entity
import (
"gitea.redpowerfuture.com/red-future/common/beans"
)
// LogsModelPpCol lists the database column names of the operation log table,
// one string field per column, on top of the shared base columns.
// NOTE(review): the "Pp" in the type name looks like a typo for "Op" (the
// variable of this type is LogsModelOpCol); the name is exported, so it is
// left unchanged here.
type LogsModelPpCol struct {
beans.SQLBaseCol
IP string
UserAgent string
APIPath string
HttpMethod string
BizName string
ModelName string
TaskID string
OpType string
Success string
ErrorMsg string
CostMs string
RequestPayload string
ResponsePayload string
}
// LogsModelOpCol holds the concrete column names of the operation log table.
var LogsModelOpCol = LogsModelPpCol{
SQLBaseCol: beans.DefSQLBaseCol,
IP: "ip",
UserAgent: "user_agent",
APIPath: "api_path",
HttpMethod: "http_method",
BizName: "biz_name",
ModelName: "model_name",
TaskID: "task_id",
OpType: "op_type",
Success: "success",
ErrorMsg: "error_msg",
CostMs: "cost_ms",
RequestPayload: "request_payload",
ResponsePayload: "response_payload",
}
// LogsModelOp is the ORM entity for one operation log row (for example a
// task creation): request origin, the operation performed, its outcome and
// cost, and the request/response payloads.
type LogsModelOp struct {
beans.SQLBaseDO `orm:",inline"`
IP string `orm:"ip" json:"ip"`
UserAgent string `orm:"user_agent" json:"userAgent"`
APIPath string `orm:"api_path" json:"apiPath"`
HttpMethod string `orm:"http_method" json:"httpMethod"`
BizName string `orm:"biz_name" json:"bizName"`
ModelName string `orm:"model_name" json:"modelName"`
TaskID string `orm:"task_id" json:"taskId"`
OpType string `orm:"op_type" json:"opType"`
Success int `orm:"success" json:"success"`
ErrorMsg string `orm:"error_msg" json:"errorMsg"`
CostMs int64 `orm:"cost_ms" json:"costMs"`
// Payloads are untyped here; any value may be stored.
RequestPayload any `orm:"request_payload" json:"requestPayload"`
ResponsePayload any `orm:"response_payload" json:"responsePayload"`
}

View File

@@ -1,38 +0,0 @@
package entity
import (
"github.com/gogf/gf/v2/os/gtime"
)
// LogsModelStatCol lists the database column names of the per-day model
// statistics table, one string field per column.
type LogsModelStatCol struct {
Day string
TenantId string
Creator string
ModelName string
RequestCount string
CreatedAt string
UpdatedAt string
}
// LogsModelStatCols holds the concrete column names of the per-day model
// statistics table.
var LogsModelStatCols = LogsModelStatCol{
Day: "day",
TenantId: "tenant_id",
Creator: "creator",
ModelName: "model_name",
RequestCount: "request_count",
CreatedAt: "created_at",
UpdatedAt: "updated_at",
}
// LogsModelStat is one per-day statistics row: the number of requests for a
// given day, tenant, creator and model.
// Note: this entity does not embed the shared SQLBaseDO; per the original
// author's note it relies on a composite unique key
// (day, tenant_id, creator, model_name) so counts can be accumulated
// atomically with an UPSERT.
type LogsModelStat struct {
Day *gtime.Time `orm:"day" json:"day"` // date (using only the date part is recommended)
TenantId int64 `orm:"tenant_id" json:"tenantId"` // tenant ID
Creator string `orm:"creator" json:"creator"` // creator / operator
ModelName string `orm:"model_name" json:"modelName"` // model name
RequestCount int64 `orm:"request_count" json:"requestCount"` // number of requests
CreatedAt *gtime.Time `orm:"created_at" json:"createdAt"` // creation time
UpdatedAt *gtime.Time `orm:"updated_at" json:"updatedAt"` // last update time
}

View File

@@ -0,0 +1,56 @@
package entity
import "gitea.redpowerfuture.com/red-future/common/beans"
// modelGatewayLogsOpCol lists the database column names of the model gateway
// operation log table, one string field per column, on top of the shared base
// columns. The exported variable ModelGatewayLogsOpCol is the instance to use.
type modelGatewayLogsOpCol struct {
beans.SQLBaseCol
IP string
UserAgent string
APIPath string
HttpMethod string
BizName string
ModelName string
TaskID string
OpType string
Success string
ErrorMsg string
CostMs string
RequestPayload string
ResponsePayload string
}
// ModelGatewayLogsOpCol holds the concrete column names of the model gateway
// operation log table.
var ModelGatewayLogsOpCol = modelGatewayLogsOpCol{
SQLBaseCol: beans.DefSQLBaseCol,
IP: "ip",
UserAgent: "user_agent",
APIPath: "api_path",
HttpMethod: "http_method",
BizName: "biz_name",
ModelName: "model_name",
TaskID: "task_id",
OpType: "op_type",
Success: "success",
ErrorMsg: "error_msg",
CostMs: "cost_ms",
RequestPayload: "request_payload",
ResponsePayload: "response_payload",
}
// ModelGatewayLogsOp is the ORM entity for one model gateway operation log
// row: request origin, the operation performed, its outcome and cost, and the
// request/response payloads.
type ModelGatewayLogsOp struct {
beans.SQLBaseDO `orm:",inline"`
IP string `orm:"ip" json:"ip"`
UserAgent string `orm:"user_agent" json:"userAgent"`
APIPath string `orm:"api_path" json:"apiPath"`
HttpMethod string `orm:"http_method" json:"httpMethod"`
BizName string `orm:"biz_name" json:"bizName"`
ModelName string `orm:"model_name" json:"modelName"`
TaskID string `orm:"task_id" json:"taskId"`
OpType string `orm:"op_type" json:"opType"`
Success int `orm:"success" json:"success"`
ErrorMsg string `orm:"error_msg" json:"errorMsg"`
CostMs int64 `orm:"cost_ms" json:"costMs"`
// RequestPayload is a typed headers+body structure; it is a pointer and may be nil.
RequestPayload *RequestPayload `orm:"request_payload" json:"requestPayload"`
ResponsePayload map[string]any `orm:"response_payload" json:"responsePayload"`
}

View File

@@ -0,0 +1,35 @@
package entity
import "github.com/gogf/gf/v2/os/gtime"
// ModelGatewayLogsStatCol lists the database column names of the per-day
// model gateway statistics table, one string field per column.
type ModelGatewayLogsStatCol struct {
Day string
TenantId string
Creator string
ModelName string
RequestCount string
CreatedAt string
UpdatedAt string
}
// ModelGatewayLogsStatCols holds the concrete column names of the per-day
// model gateway statistics table.
var ModelGatewayLogsStatCols = ModelGatewayLogsStatCol{
Day: "day",
TenantId: "tenant_id",
Creator: "creator",
ModelName: "model_name",
RequestCount: "request_count",
CreatedAt: "created_at",
UpdatedAt: "updated_at",
}
// ModelGatewayLogsStat is one per-day statistics row: the request count for a
// given day, tenant, creator and model. It does not embed the shared
// SQLBaseDO; created/updated timestamps are declared explicitly.
type ModelGatewayLogsStat struct {
Day *gtime.Time `orm:"day" json:"day"`
// TenantId is unsigned in this entity.
TenantId uint64 `orm:"tenant_id" json:"tenantId"`
Creator string `orm:"creator" json:"creator"`
ModelName string `orm:"model_name" json:"modelName"`
RequestCount int64 `orm:"request_count" json:"requestCount"`
CreatedAt *gtime.Time `orm:"created_at" json:"createdAt"`
UpdatedAt *gtime.Time `orm:"updated_at" json:"updatedAt"`
}

View File

@@ -2,7 +2,7 @@ package entity
import "gitea.redpowerfuture.com/red-future/common/beans"
type asynchModelCol struct {
type modelGatewayModelCol struct {
beans.SQLBaseCol
ModelName string
ModelType string
@@ -32,10 +32,10 @@ type asynchModelCol struct {
StreamConfig string
FirstFrame string
LastFrame string
CallbackUrl string
MaxTokens string
}
var AsynchModelCol = asynchModelCol{
var ModelGatewayModelCol = modelGatewayModelCol{
SQLBaseCol: beans.DefSQLBaseCol,
ModelName: "model_name",
ModelType: "model_type",
@@ -65,11 +65,10 @@ var AsynchModelCol = asynchModelCol{
StreamConfig: "stream_config",
FirstFrame: "first_frame",
LastFrame: "last_frame",
CallbackUrl: "callback_url",
MaxTokens: "max_tokens",
}
// AsynchModel 异步模型配置
type AsynchModel struct {
type ModelGatewayModel struct {
beans.SQLBaseDO `orm:",inline"`
ModelName string `orm:"model_name" json:"modelName"`
ModelType int `orm:"model_type" json:"modelType"`
@@ -80,7 +79,7 @@ type AsynchModel struct {
RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"`
ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"`
ResponseBody string `orm:"response_body" json:"responseBody"`
ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"`
ResponseTokenField string `orm:"response_token_field" json:"tokenField"`
RequiredFields []string `orm:"required_fields" json:"requiredFields"`
IsPrivate *int `orm:"is_private" json:"isPrivate"`
IsChatModel *int `orm:"is_chat_model" json:"isChatModel"`
@@ -91,7 +90,7 @@ type AsynchModel struct {
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
RetryTimes int `orm:"retry_times" json:"retryTimes"`
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
IsOwner *int `json:"isOwner" orm:"is_owner"`
IsOwner *int `orm:"is_owner" json:"isOwner"`
OperatorName string `orm:"operator_name" json:"operatorName"`
TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"`
ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"`
@@ -99,5 +98,5 @@ type AsynchModel struct {
StreamConfig map[string]any `orm:"stream_config" json:"streamConfig"`
FirstFrame string `orm:"first_frame" json:"firstFrame"`
LastFrame string `orm:"last_frame" json:"lastFrame"`
CallbackUrl string `orm:"callback_url" json:"callbackUrl"`
MaxTokens int `orm:"max_tokens" json:"maxTokens"`
}

View File

@@ -0,0 +1,76 @@
package entity
import (
"gitea.redpowerfuture.com/red-future/common/beans"
)
// modelGatewayTaskCol lists the database column names of the model gateway
// task table, one string field per column, on top of the shared base columns.
type modelGatewayTaskCol struct {
beans.SQLBaseCol
ModelName string
TaskID string
BizName string
CallbackURL string
State string
Phase string
ErrorMsg string
ResultFile string
TextResult string
ExpendTokens string
DurationSeconds string
RetryCount string
TmpFile string
RequestPayload string
EpicycleId string
}
// ModelGatewayTaskCol holds the concrete column names of the model gateway
// task table so queries can reference columns without repeating literals.
var ModelGatewayTaskCol = modelGatewayTaskCol{
SQLBaseCol: beans.DefSQLBaseCol,
ModelName: "model_name",
TaskID: "task_id",
BizName: "biz_name",
CallbackURL: "callback_url",
State: "state",
Phase: "phase",
ErrorMsg: "error_msg",
ResultFile: "result_file",
TextResult: "text_result",
ExpendTokens: "expend_tokens",
DurationSeconds: "duration_seconds",
RetryCount: "retry_count",
TmpFile: "tmp_file",
RequestPayload: "request_payload",
EpicycleId: "epicycle_id",
}
// ModelGatewayTask is the ORM entity for one model gateway task row.
type ModelGatewayTask struct {
beans.SQLBaseDO `orm:",inline"`
ModelName string `orm:"model_name" json:"modelName"`
TaskID string `orm:"task_id" json:"taskId"`
BizName string `orm:"biz_name" json:"bizName"`
CallbackURL string `orm:"callback_url" json:"callbackUrl"`
State int `orm:"state" json:"state"`
Phase int `orm:"phase" json:"phase"`
ErrorMsg string `orm:"error_msg" json:"errorMsg"`
// ResultFile is a pointer and may be nil; callers must check before
// reading its fields.
ResultFile *ResultFile `orm:"result_file" json:"resultFile"`
// Note: the JSON key is "text", not "textResult".
TextResult map[string]any `orm:"text_result" json:"text"`
ExpendTokens int64 `orm:"expend_tokens" json:"expendTokens"`
DurationSeconds int64 `orm:"duration_seconds" json:"durationSeconds"`
RetryCount int `orm:"retry_count" json:"retryCount"`
TmpFile string `orm:"tmp_file" json:"tmpFile"`
// RequestPayload is a pointer and may be nil.
RequestPayload *RequestPayload `orm:"request_payload" json:"requestPayload"`
EpicycleId int64 `orm:"epicycle_id" json:"epicycleId"`
}
// ResultFile describes a task's result file stored in OSS: its address, file
// type and size.
// NOTE(review): the unit of FileSize is not shown here — presumably bytes; confirm.
type ResultFile struct {
OssFile string `json:"ossFile"`
FileType string `json:"fileType"`
FileSize int64 `json:"fileSize"`
}
// RequestPayload is the structured request stored with a task or log entry:
// the headers and the body of the request.
type RequestPayload struct {
Headers map[string]string `json:"headers"`
Body map[string]any `json:"body"`
}

View File

@@ -24,6 +24,7 @@ type UploadFileResponse struct {
FileAddressPrefix string `json:"fileAddressPrefix"` // 文件地址前缀
}
// UploadByTask 通过任务上传文件
func UploadByTask(ctx context.Context, data []byte, fileExt string) (oss *UploadFileResponse, err error) {
// multipart
body := &bytes.Buffer{}
@@ -68,24 +69,22 @@ func UploadByTask(ctx context.Context, data []byte, fileExt string) (oss *Upload
// CallbackPayload 回调请求体
type CallbackPayload struct {
TaskId string `json:"task_id"`
State int `json:"state"`
OssFile string `json:"oss_file"`
FileType string `json:"file_type"`
Messages map[string]any `json:"messages"`
ErrorMsg string `json:"error_msg"`
TaskId string `json:"task_id"`
State int `json:"state"`
OssFile string `json:"oss_file"`
FileType string `json:"file_type"`
ErrorMsg string `json:"error_msg"`
}
// TriggerCallback 任务的回调
func TriggerCallback(ctx context.Context, t *entity.AsynchTask) {
func TriggerCallback(ctx context.Context, t *entity.ModelGatewayTask) {
headers := util.ForwardHeaders(ctx)
var resp struct{}
payload := CallbackPayload{
TaskId: t.TaskID,
State: t.State,
OssFile: t.OssFile,
FileType: t.FileType,
Messages: t.TextResult,
OssFile: t.ResultFile.OssFile,
FileType: t.ResultFile.FileType,
ErrorMsg: t.ErrorMsg,
}
jsonData, err := json.Marshal(payload)
@@ -111,7 +110,7 @@ type PromptsCallbackPayload struct {
}
// TriggerPromptsCallback 任务成功后的提示词回调
func TriggerPromptsCallback(ctx context.Context, t *entity.AsynchTask, epicycleId int64) {
func TriggerPromptsCallback(ctx context.Context, t *entity.ModelGatewayTask, epicycleId int64) {
callbackURL := "prompts-core/session/callback"
headers := util.ForwardHeaders(ctx)
var resp struct{}

View File

@@ -1,99 +0,0 @@
package job
import (
"context"
"model-gateway/model/dto"
"model-gateway/service/queue"
"os"
"time"
"model-gateway/dao"
"github.com/gogf/gf/v2/frame/g"
)
// Cleaner is the package-level instance used to run cleanup passes.
var Cleaner = &cleaner{}
// cleaner is a stateless type that carries the cleanup/retry logic.
type cleaner struct{}
// RunOnce performs a single cleanup/retry pass. It is meant to be triggered
// by an outer scheduler and runs four independent steps:
//
//  1. hard-delete tasks that were downloaded (state=4) and have expired;
//  2. mark timed-out tasks as failed and release their queue slot;
//  3. re-enqueue failed tasks (state=3) that still have retries left;
//  4. hard-delete failed tasks whose retries are exhausted.
//
// A failing step is logged and does not stop the following steps, so the
// method always returns Ok=true with a nil error. Per-task failures are
// logged at warning level instead of being silently dropped.
func (c *cleaner) RunOnce(ctx context.Context) (res *dto.CleanWorkRes, err error) {
	// 1) Hard-delete downloaded (state=4) tasks that have expired.
	expired, err := dao.Task.ListExpiredDownloadedGlobal(ctx, 200)
	if err != nil {
		g.Log().Errorf(ctx, "[清理] 查询已下载过期任务失败: %v", err)
	} else {
		for _, t := range expired {
			// Best effort: the temp file may be missing or the path empty.
			_ = os.Remove(t.TmpFile)
			if delErr := dao.Task.HardDeleteByIDGlobal(ctx, t.Id); delErr != nil {
				g.Log().Warningf(ctx, "[清理] 硬删除已下载过期任务失败, task_id=%s: %v", t.TaskID, delErr)
			}
		}
		g.Log().Infof(ctx, "[清理] 已下载过期任务清理完成, count=%d", len(expired))
	}

	// 2) Mark timed-out tasks as failed.
	list, err := dao.Task.ListTimeoutTasksGlobal(ctx, 200)
	if err != nil {
		g.Log().Errorf(ctx, "[清理] 查询超时任务失败: %v", err)
	} else {
		for _, t := range list {
			t.ErrorMsg = "任务超时自动失败"
			if updErr := dao.Task.UpdateFailedGlobal(ctx, t); updErr != nil {
				g.Log().Warningf(ctx, "[清理] 超时任务标记失败状态失败, task_id=%s: %v", t.TaskID, updErr)
			}
			// The slot is released regardless, matching the previous behavior.
			queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
		}
		g.Log().Infof(ctx, "[清理] 超时任务处理完成, count=%d", len(list))
	}

	// 3) Re-enqueue failed (state=3) tasks that still have retries left
	// according to the model's retry_times.
	retryable, err := dao.Task.ListFailedRetryableGlobal(ctx, 200)
	if err != nil {
		g.Log().Errorf(ctx, "[清理] 查询可重试任务失败: %v", err)
	} else {
		for _, t := range retryable {
			// The model configuration supplies the concurrency and timeout
			// used to size and hold the queue slot.
			m, getErr := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
			if getErr != nil {
				g.Log().Warningf(ctx, "[清理] 获取模型配置失败, model=%s: %v", t.ModelName, getErr)
				continue
			}
			if m == nil {
				continue
			}

			// A queue slot must be held before the task goes back to state 0.
			// If none is available the task stays failed and is retried on
			// the next pass.
			slotHeld := false
			limit := queue.GetRuntimeQueueLimit(ctx, t.ModelName, m.MaxConcurrency*2)
			if limit > 0 {
				// An acquire error is treated the same as "no slot available".
				ok, _ := queue.AcquireQueueSlot(ctx, t.ModelName, t.TaskID, limit, m.TimeoutSeconds)
				if !ok {
					continue
				}
				slotHeld = true
			}

			// retry_queue_max_seconds controls where the retry is queued:
			//   = 0: jump to the head of the queue;
			//   > 0: jump to the head only when the task has existed for at
			//        least that many seconds, otherwise go to the tail.
			now := time.Now()
			enqueueAt := now
			maxSeconds := t.RetryQueueMaxSeconds
			waitedLongEnough := maxSeconds > 0 && t.CreatedAt != nil &&
				now.Sub(t.CreatedAt.Time) >= time.Duration(maxSeconds)*time.Second
			if maxSeconds == 0 || waitedLongEnough {
				// A timestamp far in the past sorts before every real entry.
				enqueueAt = now.Add(-100 * 365 * 24 * time.Hour)
			}

			if reqErr := dao.Task.RequeueForRetryGlobal(ctx, t.Id, enqueueAt); reqErr != nil {
				g.Log().Warningf(ctx, "[清理] 任务重新入队失败, task_id=%s: %v", t.TaskID, reqErr)
				// The task did not go back to the queue, so give the slot
				// back instead of holding it for a task that is still failed.
				if slotHeld {
					queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
				}
			}
		}
		g.Log().Infof(ctx, "[清理] 可重试任务重新入队完成, count=%d", len(retryable))
	}

	// 4) Hard-delete failed (state=3) tasks whose retries are exhausted.
	exhausted, err := dao.Task.ListFailedExhaustedGlobal(ctx, 200)
	if err != nil {
		g.Log().Errorf(ctx, "[清理] 查询重试耗尽任务失败: %v", err)
	} else {
		for _, t := range exhausted {
			// Best effort: the temp file may be missing or the path empty.
			_ = os.Remove(t.TmpFile)
			// Safety net: release the queue slot; the original author notes
			// this is idempotent if the slot was already released.
			queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
			if delErr := dao.Task.HardDeleteByIDGlobal(ctx, t.Id); delErr != nil {
				g.Log().Warningf(ctx, "[清理] 硬删除重试耗尽任务失败, task_id=%s: %v", t.TaskID, delErr)
			}
		}
		g.Log().Infof(ctx, "[清理] 重试耗尽任务清理完成, count=%d", len(exhausted))
	}

	return &dto.CleanWorkRes{
		Ok: true,
	}, nil
}

View File

@@ -18,7 +18,7 @@ import (
"github.com/gogf/gf/v2/util/gconv"
)
var Model = &modelService{}
var ModelGatewayModels = &modelService{}
type modelService struct{}
@@ -37,7 +37,7 @@ func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (*dt
}
// 3入库
id, err := dao.Model.Insert(ctx, util.ConvertTo[entity.AsynchModel](req))
id, err := dao.ModelGatewayModels.Insert(ctx, util.ConvertTo[entity.ModelGatewayModel](req))
if err != nil {
return nil, err
}
@@ -56,27 +56,27 @@ func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) erro
req.IsOwner = gconv.PtrInt(1)
if isAdmin, _ := gateway.IsSuperAdmin(ctx); isAdmin {
req.IsOwner = gconv.PtrInt(0)
_, err := dao.Model.Update(ctx, util.ConvertTo[entity.AsynchModel](req))
_, err := dao.ModelGatewayModels.Update(ctx, util.ConvertTo[entity.ModelGatewayModel](req))
return err
}
// 3跨租户判断超管的模型不允许直接修改走插入新记录
model, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
model, err := dao.ModelGatewayModels.GetByAcrossTenant(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
})
if err != nil {
return err
}
if model.TenantId == 1 {
_, err = dao.Model.Insert(ctx, util.ConvertTo[entity.AsynchModel](req))
_, err = dao.ModelGatewayModels.Insert(ctx, util.ConvertTo[entity.ModelGatewayModel](req))
return err
}
_, err = dao.Model.Update(ctx, util.ConvertTo[entity.AsynchModel](req))
_, err = dao.ModelGatewayModels.Update(ctx, util.ConvertTo[entity.ModelGatewayModel](req))
return err
}
// Delete 删除模型
func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) error {
_, err := dao.Model.Delete(ctx, &entity.AsynchModel{
_, err := dao.ModelGatewayModels.Delete(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
})
return err
@@ -91,7 +91,7 @@ func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetM
if g.IsEmpty(req.ID) {
req.Creator = user.UserName
}
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{
Id: req.ID,
Creator: user.UserName,
@@ -123,7 +123,7 @@ func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (*dto.Li
req.Creator = user.UserName
// 3查询
models, total, err := dao.Model.GetByCreatorAndPlatform(ctx, req)
models, total, err := dao.ModelGatewayModels.GetByCreatorAndPlatform(ctx, req)
if err != nil {
return nil, err
}
@@ -134,7 +134,7 @@ func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (*dto.Li
// UpdateChatModel 设置会话模型
func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error {
// 1校验新模型存在
newModel, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
newModel, err := dao.ModelGatewayModels.GetByAcrossTenant(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
})
if err != nil || newModel == nil {
@@ -146,7 +146,7 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM
if err != nil {
return err
}
currentModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
currentModel, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
IsChatModel: gconv.PtrInt(1),
})
@@ -161,7 +161,7 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM
return errors.New("当前模型为非推理模型,不能设置为会话模型")
}
if currentModel.Id != req.Id {
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
_, err = dao.ModelGatewayModels.Update(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id},
IsChatModel: gconv.PtrInt(0),
})
@@ -171,7 +171,7 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM
}
}
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
_, err = dao.ModelGatewayModels.Update(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
IsChatModel: gconv.PtrInt(1),
})
@@ -185,7 +185,7 @@ func (s *modelService) GetIsChatModel(ctx context.Context) (*dto.GetIsChatModelR
if err != nil {
return nil, err
}
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
IsChatModel: gconv.PtrInt(1),
})
@@ -203,14 +203,14 @@ func (s *modelService) clearUserChatModel(ctx context.Context) error {
if err != nil {
return err
}
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
IsChatModel: gconv.PtrInt(1),
})
if err != nil || model == nil {
return nil
}
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
_, err = dao.ModelGatewayModels.Update(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Id: model.Id},
IsChatModel: gconv.PtrInt(0),
})
@@ -223,7 +223,7 @@ func (s *modelService) checkChatModelUnique(ctx context.Context) error {
if err != nil {
return err
}
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
IsChatModel: gconv.PtrInt(1),
})

View File

@@ -43,14 +43,14 @@ func AutoTune(ctx context.Context, req *dto.AutoTuneReq) (res *dto.AutoTuneRes,
req.WindowSeconds = 3600 // 默认1小时
}
// 1) 读取模型配置cap按 model_name 聚合去重(如果表里有多租户重复数据,取较大上限)
var modelRows []*entity.AsynchModel
var modelRows []*entity.ModelGatewayModel
if err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
Where("deleted_at IS NULL").
Where(entity.AsynchModelCol.Enabled, 1).
Where(entity.ModelGatewayModelCol.Enabled, 1).
Scan(&modelRows); err != nil {
return nil, err
}
modelMap := make(map[string]*entity.AsynchModel)
modelMap := make(map[string]*entity.ModelGatewayModel)
for _, m := range modelRows {
if m == nil || m.ModelName == "" {
continue

View File

@@ -2,36 +2,31 @@ package stat
import (
"context"
"model-gateway/model/entity"
"model-gateway/dao"
"model-gateway/model/dto"
)
type statService struct{}
var ModelGatewayLogsStat = &logsStatService{}
var Stat = &statService{}
type logsStatService struct{}
func (s *statService) List(ctx context.Context, req *dto.ListModelStatReq) (res *dto.ListModelStatRes, err error) {
pageNum, pageSize := 1, 10
if req != nil {
if req.PageNum > 0 {
pageNum = req.PageNum
}
if req.PageSize > 0 {
pageSize = req.PageSize
}
func (s *logsStatService) List(ctx context.Context, req *dto.ListModelStatReq) (*dto.ListModelStatRes, error) {
if req == nil {
req = &dto.ListModelStatReq{}
}
startDay, endDay := "", ""
var tenantID *int64
creator, modelName := "", ""
if req != nil {
startDay = req.StartDay
endDay = req.EndDay
tenantID = req.TenantID
creator = req.Creator
modelName = req.ModelName
if req.PageNum <= 0 {
req.PageNum = 1
}
list, total, err := dao.Stat.List(ctx, pageNum, pageSize, startDay, endDay, tenantID, creator, modelName)
if req.PageSize <= 0 {
req.PageSize = 10
}
list, total, err := dao.ModelGatewayLogsStat.List(ctx, req.PageNum, req.PageSize, &entity.ModelGatewayLogsStat{
Creator: req.Creator,
ModelName: req.ModelName,
})
if err != nil {
return nil, err
}

View File

@@ -17,25 +17,24 @@ import (
"gitea.redpowerfuture.com/red-future/common/utils"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/util/gconv"
"github.com/google/uuid"
)
var Task = &taskService{}
var ModelGatewayTask = &taskService{}
type taskService struct{}
// Create 创建任务
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
startAt := time.Now()
taskID := uuid.NewString()
// 1) 检查模型配置,并且获取模型
userInfo, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{
TenantId: userInfo.TenantId,
Creator: userInfo.UserName,
@@ -62,23 +61,17 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
}
// 3) 插入任务记录
if model.CallMode != nil && *model.CallMode == public.CallModeAsync {
// 异步调用:注入回调地址后提交,拿到 task_id 轮询
req.RequestPayload = util.InjectCallbackURL(ctx, req.RequestPayload, model.CallbackUrl)
requestPayload := entity.RequestPayload{
Body: req.RequestPayload,
Headers: util.ParseHeadMsgHeaders(model.HeadMsg),
}
storedPayload := map[string]any{
"headers": util.ParseHeadMsgHeaders(model.HeadMsg),
"body": req.RequestPayload,
}
_, err = dao.Task.Insert(ctx, &entity.AsynchTask{
id, err := dao.ModelGatewayTask.Insert(ctx, &entity.ModelGatewayTask{
ModelName: req.ModelName,
TaskID: taskID,
State: 0,
State: public.TaskStatusPending,
BizName: req.BizName,
CallbackURL: req.CallbackUrl,
ModelKey: model.ApiKey,
InputRef: req.InputRef,
RequestPayload: storedPayload,
RequestPayload: &requestPayload,
EpicycleId: req.EpicycleId,
})
if err != nil { // 入库失败:回滚闸门占位
@@ -97,7 +90,7 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
apiPath = r.URL.Path
httpMethod = r.Method
}
_, _ = dao.OpLog.Insert(ctx, &entity.LogsModelOp{
_, _ = dao.ModelGatewayLogsOp.Insert(ctx, &entity.ModelGatewayLogsOp{
IP: ip,
UserAgent: ua,
APIPath: apiPath,
@@ -107,22 +100,18 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
TaskID: taskID,
OpType: "createTask",
Success: 1,
ErrorMsg: "",
CostMs: time.Since(startAt).Milliseconds(),
RequestPayload: storedPayload,
CostMs: time.Since(time.Now()).Milliseconds(),
RequestPayload: &requestPayload,
ResponsePayload: gdb.Map{
"taskId": taskID,
},
})
// 5) 获取任务信息
task, err := dao.Task.ClaimPendingByTaskIDGlobal(ctx, taskID)
task, err := dao.ModelGatewayTask.ClaimByID(ctx, id)
if err != nil {
return nil, err
}
if task == nil {
return nil, err
}
// 5) 创建成功后立即异步尝试执行当前任务
go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req)
@@ -130,10 +119,96 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
return &dto.CreateTaskRes{TaskID: taskID}, nil
}
// GetResult looks up a task by its external task ID and reports its current
// state together with the OSS address of its result file, if one exists.
//
// It returns an error when the lookup fails or when no task matches taskID.
// OssFile is left empty for tasks that have no result file recorded.
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
	t, err := dao.ModelGatewayTask.Get(ctx, &entity.ModelGatewayTask{
		TaskID: taskID,
	})
	if err != nil {
		return nil, err
	}
	if t == nil {
		return nil, errors.New("任务不存在")
	}

	res = &dto.GetTaskResultRes{State: t.State}
	// ResultFile is a pointer and may be nil for a task that has not produced
	// a file yet; dereferencing it unconditionally would panic.
	if t.ResultFile != nil {
		res.OssFile = t.ResultFile.OssFile
	}
	return res, nil
}
// GetBatch looks up the tasks named in req.TaskIDs and returns one item per
// task found.
//
// Tasks that are in the success state are marked as downloaded in the
// database; the returned item reflects the downloaded state only when that
// update succeeded. A nil request or an empty ID list yields an empty,
// non-nil list. An error is returned only when the initial lookup fails.
func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
	if req == nil || len(req.TaskIDs) == 0 {
		return &dto.GetTaskBatchRes{List: []dto.GetTaskBatchItem{}}, nil
	}

	// 1) Load the requested tasks.
	list, err := dao.ModelGatewayTask.ListByTaskIDs(ctx, req.TaskIDs)
	if err != nil {
		return nil, err
	}

	items := make([]dto.GetTaskBatchItem, 0, len(list))
	for _, t := range list {
		if t == nil {
			continue
		}

		// 2) Successful tasks are flipped to "downloaded". The comparison uses
		// the task-status constant (the same one the worker writes on
		// success) rather than an unrelated constant.
		if t.State == public.TaskStatusSuccess {
			if markErr := dao.ModelGatewayTask.MarkDownloadedByID(ctx, t.Id); markErr != nil {
				// Best effort: keep serving the batch, but do not claim a
				// state the database does not hold.
				g.Log().Warningf(ctx, "[批量查询] 标记已下载失败 taskId=%s err=%v", t.TaskID, markErr)
			} else {
				// Keep the in-memory copy in step with the database.
				t.State = public.TaskStatusDownloaded
			}
		}

		// 3) Build the response item.
		item := dto.GetTaskBatchItem{
			TaskID:     t.TaskID,
			State:      t.State,
			TextResult: t.TextResult,
		}
		// ResultFile is a pointer and may be nil (e.g. queued or failed
		// tasks); guard the dereference.
		if t.ResultFile != nil {
			item.OssFile = t.ResultFile.OssFile
		}
		items = append(items, item)
	}

	return &dto.GetTaskBatchRes{List: items}, nil
}
// List returns one page of the current user's tasks, filtered by the model
// name, business name, state and task ID carried in req.
//
// A nil req is treated as an empty request. PageNum defaults to 1 and
// PageSize to 10 when they are not positive. An error is returned when the
// user cannot be resolved from ctx or when the query fails.
func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (*dto.ListTaskRes, error) {
	// Guard against a nil request so the field accesses below cannot panic.
	if req == nil {
		req = &dto.ListTaskReq{}
	}
	if req.PageNum <= 0 {
		req.PageNum = 1
	}
	if req.PageSize <= 0 {
		req.PageSize = 10
	}

	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return nil, err
	}

	list, total, err := dao.ModelGatewayTask.List(ctx, req.PageNum, req.PageSize, &entity.ModelGatewayTask{
		SQLBaseDO: beans.SQLBaseDO{
			Creator: user.UserName,
		},
		ModelName: req.ModelName,
		BizName:   req.BizName,
		State:     req.State,
		TaskID:    req.TaskID,
	})
	if err != nil {
		return nil, err
	}
	return &dto.ListTaskRes{List: list, Total: total}, nil
}
// ModelTaskCallback 模型异步任务的回调通知
func (s *taskService) ModelTaskCallback(ctx context.Context, req *dto.ModelTaskCallbackReq) (*dto.ModelTaskCallbackRes, error) {
g.Log().Infof(ctx, "[模型回调] 收到通知 taskID=%s status=%s", req.TaskID, req.Status)
// 1. 查本地任务
task, err := dao.Task.Get(ctx, &entity.AsynchTask{
task, err := dao.ModelGatewayTask.Get(ctx, &entity.ModelGatewayTask{
TaskID: req.TaskID,
})
if err != nil || task == nil {
@@ -167,7 +242,7 @@ func (s *taskService) QueryPendingTasks(ctx context.Context, req *dto.QueryPendi
}
// 1. 查 state=1执行中的异步任务
tasks, err := dao.Task.GetPendingAsyncTasks(ctx, limit)
tasks, err := dao.ModelGatewayTask.GetPendingAsyncTasks(ctx, limit)
if err != nil {
return nil, err
}
@@ -176,7 +251,7 @@ func (s *taskService) QueryPendingTasks(ctx context.Context, req *dto.QueryPendi
var results []dto.QueryTaskItem
for _, t := range tasks {
// 拿到模型配置
model, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
model, err := dao.ModelGatewayModels.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
if err != nil || model == nil || model.QueryConfig == nil {
continue
}
@@ -206,100 +281,3 @@ func (s *taskService) QueryPendingTasks(ctx context.Context, req *dto.QueryPendi
Results: results,
}, nil
}
// GetResult fetches the task with the given external task ID and returns its
// state and OSS file address. It fails when the lookup errors or when no
// such task exists.
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
	filter := &entity.AsynchTask{TaskID: taskID}
	task, err := dao.Task.Get(ctx, filter)
	switch {
	case err != nil:
		return nil, err
	case task == nil:
		return nil, errors.New("任务不存在")
	}

	res = &dto.GetTaskResultRes{
		OssFile: task.OssFile,
		State:   task.State,
	}
	return res, nil
}
// GetBatch fetches the tasks listed in req.TaskIDs and returns one item per
// task found.
//
// Every task in state 2 is marked as state 4 with an expiry time computed
// from its model's AutoCleanSeconds (86400 seconds when the model is absent
// or the value is not positive). A nil request or empty ID list yields an
// empty, non-nil list. A failed task or model lookup aborts the whole call.
func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
	if req == nil || len(req.TaskIDs) == 0 {
		return &dto.GetTaskBatchRes{List: []dto.GetTaskBatchItem{}}, nil
	}

	tasks, err := dao.Task.ListByTaskIDs(ctx, req.TaskIDs)
	if err != nil {
		return nil, err
	}

	// One timestamp for the whole batch so all expiries share a base.
	now := time.Now()
	items := make([]dto.GetTaskBatchItem, 0, len(tasks))
	for _, task := range tasks {
		if task == nil {
			continue
		}

		if task.State == 2 {
			// Retention is decided by the model configuration.
			cfg, err := dao.Model.Get(ctx, &entity.AsynchModel{
				ModelName: task.ModelName,
			})
			if err != nil {
				return nil, err
			}
			keepSeconds := 86400
			if cfg != nil && cfg.AutoCleanSeconds > 0 {
				keepSeconds = cfg.AutoCleanSeconds
			}

			expireAt := gtime.New(now.Add(time.Duration(keepSeconds) * time.Second))
			_ = dao.Task.MarkDownloadedByID(ctx, task.Id, expireAt)

			// Mirror the update in memory so this response is consistent.
			task.State = 4
			task.ExpireAt = expireAt
		}

		items = append(items, dto.GetTaskBatchItem{
			TaskID:  task.TaskID,
			State:   task.State,
			OssFile: task.OssFile,
		})
	}

	return &dto.GetTaskBatchRes{List: items}, nil
}
// List returns one page of tasks filtered by the model name, task ID and
// state carried in req. A nil req means no filters; the page number defaults
// to 1 and the page size to 10 when they are not positive.
func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
	var (
		pageNum, pageSize = 1, 10
		modelName, taskID string
		state             *int
	)
	if req != nil {
		if req.PageNum > 0 {
			pageNum = req.PageNum
		}
		if req.PageSize > 0 {
			pageSize = req.PageSize
		}
		modelName, taskID, state = req.ModelName, req.TaskID, req.State
	}

	list, total, err := dao.Task.List(ctx, pageNum, pageSize, modelName, taskID, state)
	if err != nil {
		return nil, err
	}

	res = &dto.ListTaskRes{List: list, Total: total}
	return res, nil
}

View File

@@ -21,6 +21,7 @@ import (
"model-gateway/service/gateway"
"model-gateway/service/queue"
"gitea.redpowerfuture.com/red-future/common/beans"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
@@ -32,14 +33,19 @@ type asyncWorker struct {
}
// handleOne 执行一次完整的任务
func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, req *dto.CreateTaskReq) {
body := util.GetModelBody(task.RequestPayload) // 核心请求参数
maxRetry := model.RetryTimes // 重试次数
startTime := time.Now()
func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq) {
var (
body = task.RequestPayload.Body
maxRetry = model.RetryTimes
startTime = time.Now()
result map[string]any
err error
)
g.Log().Infof(ctx, "[执行任务][开始] taskId=%s model=%s", task.TaskID, task.ModelName)
// ============================================
// 1) 分布式并发控制
// ============================================
semKey := fmt.Sprintf("asynch:sem:%s", task.ModelName)
maxC := queue.GetRuntimeMaxConcurrency(ctx, task.ModelName, model.MaxConcurrency)
acquired, err := queue.AcquireSemaphore(ctx, semKey, maxC, 3600)
@@ -49,88 +55,93 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo
return
}
if !acquired {
_, _ = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
SQLBaseDO: beans.SQLBaseDO{Id: task.Id},
State: public.TaskStatusPending,
})
g.Log().Infof(ctx, "[执行任务][排队] 并发已满,放回队列 taskId=%s", task.TaskID)
_ = w.rollbackToPending(ctx, task.Id)
return
}
defer func() { _ = queue.ReleaseSemaphore(ctx, semKey) }()
// ============================================
// 2) 调用模型
// ============================================
switch {
case model.CallMode != nil && *model.CallMode == public.CallModeStream:
rawBytes, err := w.callModelStream(ctx, task, model, body)
if err != nil {
w.failTask(ctx, task, startTime, err.Error())
return
}
body, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
if err != nil {
w.failTask(ctx, task, startTime, err.Error())
rawBytes, streamErr := w.callModelStream(ctx, task, model, body)
if streamErr != nil {
w.failTask(ctx, task, startTime, streamErr.Error())
return
}
result, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
case model.CallMode != nil && *model.CallMode == public.CallModeAsync:
body, err = w.callModel(ctx, task, model, body)
if err != nil {
w.failTask(ctx, task, startTime, err.Error())
return
}
body, err = util.PullTaskResult(ctx, body, model.QueryConfig, model.HeadMsg)
if err != nil {
w.failTask(ctx, task, startTime, err.Error())
return
result, err = w.callModel(ctx, task, model, body)
if err == nil {
result, err = util.PullTaskResult(ctx, result, model.QueryConfig, model.HeadMsg)
}
default:
body, err = w.callModel(ctx, task, model, body)
if err != nil {
w.failTask(ctx, task, startTime, err.Error())
return
}
result, err = w.callModel(ctx, task, model, body)
}
// 3) 保存临时文件
tmpPath, err := util.SaveTempFileByType(task.TaskID, body, task.TmpFile)
if err == nil && tmpPath != "" {
task.TmpFile = tmpPath
task.Phase = 1
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath)
}
// 4) 解析校验 + 响应映射(可重试,失败重新调模型)
body, err = w.parseAndRetry(ctx, body, task, model, req, maxRetry, startTime)
if err != nil {
task.TextResult = body
w.failTask(ctx, task, startTime, err.Error())
return
}
// ============================================
// 3) 缓存临时文件
// ============================================
if tmpPath, tmpErr := util.SaveTempFileByType(task.TaskID, result, task.TmpFile); tmpErr == nil && tmpPath != "" {
task.TmpFile = tmpPath
task.Phase = 1
_, _ = dao.ModelGatewayTask.Update(ctx, task)
}
// ============================================
// 4) 解析校验 + 响应映射(可重试)
// ============================================
result, err = w.parseAndRetry(ctx, result, task, model, req, maxRetry, startTime)
if err != nil {
task.TextResult = result
w.failTask(ctx, task, startTime, err.Error())
return
}
// ============================================
// 5) 上传 OSS可重试
// ============================================
var oss *gateway.UploadFileResponse
for attempt := 0; attempt <= maxRetry; attempt++ {
if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] OSS上传 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
}
oss, err = w.uploadOSS(ctx, task)
oss, err = gateway.UploadByTask(ctx, gjson.New(result).MustToJson(), "json")
if err == nil {
break
}
g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v",
task.TaskID, attempt, maxRetry, err)
g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
if attempt == maxRetry {
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, task.Id, err.Error())
task.State = public.TaskStatusFailed
task.ErrorMsg = err.Error()
task.Phase = 1
_, _ = dao.ModelGatewayTask.Update(ctx, task)
w.failTask(ctx, task, startTime, fmt.Sprintf("OSS上传重试耗尽: %v", err))
return
}
}
// 6) 成功回调
task.State = 2
// ============================================
// 6) 成功收尾
// ============================================
task.State = public.TaskStatusSuccess
task.DurationSeconds = int64(time.Since(startTime).Seconds())
task.OssFile = oss.FileAddressPrefix + oss.FileURL
task.FileType = oss.FileFormat
task.TextResult = body
task.FileSize = int64(oss.FileSize)
if err = dao.Task.UpdateSuccessGlobal(ctx, task); err != nil {
task.ResultFile = &entity.ResultFile{
OssFile: oss.FileAddressPrefix + oss.FileURL,
FileType: oss.FileFormat,
FileSize: int64(oss.FileSize),
}
task.TextResult = result
if _, err = dao.ModelGatewayTask.Update(ctx, task); err != nil {
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err)
return
}
@@ -141,15 +152,14 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo
go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), task, req.EpicycleId)
}
g.Log().Infof(ctx, "[执行任务][成功] taskId=%s duration=%ds fileType=%s textLen=%d callbackUrl=%s",
task.TaskID, task.DurationSeconds, oss.FileFormat, len(body), task.CallbackURL)
g.Log().Infof(ctx, "[执行任务][成功] taskId=%s duration=%ds fileType=%s",
task.TaskID, task.DurationSeconds, oss.FileFormat)
// 7) 删除临时文件
_ = os.Remove(task.TmpFile)
}
// callModelStream 调用模型,返回原始字节(不做响应映射,用于流式输出)
func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) ([]byte, error) {
func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) {
var data []byte
var err error
@@ -161,8 +171,7 @@ func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.AsynchTa
}
if data == nil {
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName)
data, err = InvokeModel(ctx, model, body, task.ModelKey)
data, err = InvokeModel(ctx, model, body)
if err != nil {
return nil, err
}
@@ -170,7 +179,10 @@ func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.AsynchTa
if tmpErr == nil && tmpPath != "" {
task.TmpFile = tmpPath
task.Phase = 1
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath)
_, err = dao.ModelGatewayTask.Update(ctx, task)
if err != nil {
g.Log().Errorf(ctx, "[执行任务][失败] 临时文件保存失败 taskId=%s err=%v", task.TaskID, tmpErr)
}
}
}
@@ -186,7 +198,7 @@ type asyncResult struct {
// asyncTaskChan 全局异步任务等待通道
var asyncTaskChan = sync.Map{} // taskID → chan asyncResult
func (w *asyncWorker) callModelAsync(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) (map[string]any, error) {
func (w *asyncWorker) callModelAsync(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, body map[string]any) (map[string]any, error) {
// 1. 提交异步任务
body, err := w.callModel(ctx, task, model, body)
if err != nil {
@@ -231,7 +243,7 @@ func NotifyAsyncResult(taskID string, result map[string]any, err error) {
// callModel 调用模型 + 检测文件类型 + 保存临时文件
// 返回: 解析后的响应体, error
func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) (map[string]any, error) {
func (w *asyncWorker) callModel(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, body map[string]any) (map[string]any, error) {
var data []byte
var err error
@@ -246,8 +258,7 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, mo
// 2) 没有可用数据,调用模型
if data == nil {
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName)
data, err = InvokeModel(ctx, model, body, task.ModelKey)
data, err = InvokeModel(ctx, model, body)
if err != nil {
return nil, err
}
@@ -258,7 +269,10 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, mo
if tmpErr == nil && tmpPath != "" {
task.TmpFile = tmpPath
task.Phase = 1
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath)
_, err = dao.ModelGatewayTask.Update(ctx, task)
if err != nil {
g.Log().Errorf(ctx, "[执行任务][失败] 临时文件保存失败 taskId=%s err=%v", task.TaskID, tmpErr)
}
}
}
@@ -279,7 +293,7 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, mo
}
// parseAndRetry 解析模型返回结果,并重试
func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.AsynchTask, model *entity.AsynchModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) {
func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) {
for attempt := 0; attempt <= maxRetry; attempt++ {
if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
@@ -296,10 +310,11 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
}
// 2) 先存 token 到数据库,防止后续失败丢失
if tokens, ok := mapped[model.ResponseTokenField]; ok {
task.ExpendTokens = gconv.Int64(tokens)
_ = dao.Task.UpdateColumns(ctx, task.Id, entity.AsynchTask{
ExpendTokens: gconv.Int64(body[model.ResponseTokenField]),
if _, ok := mapped[model.ResponseTokenField]; ok {
task.ExpendTokens = gconv.Int64(mapped[model.ResponseTokenField])
_, err = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
SQLBaseDO: beans.SQLBaseDO{Id: task.Id},
ExpendTokens: task.ExpendTokens,
})
}
@@ -325,9 +340,10 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
}
// 4) 重新调模型(直接调,不走缓存)
_ = dao.Task.IncRetryCountGlobal(ctx, task.Id)
reqBody := util.GetModelBody(task.RequestPayload)
rawData, callErr := InvokeModel(ctx, model, reqBody, task.ModelKey)
task.RetryCount++
_, _ = dao.ModelGatewayTask.Update(ctx, task)
rawData, callErr := InvokeModel(ctx, model, task.RequestPayload.Body)
if callErr != nil {
g.Log().Warningf(ctx, "[执行任务][重调模型失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, callErr)
continue
@@ -335,7 +351,7 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
// 5) 解析原始响应,覆盖 body 进入下一轮
var rawResp map[string]any
if err := json.Unmarshal(rawData, &rawResp); err != nil {
if err = json.Unmarshal(rawData, &rawResp); err != nil {
g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err)
continue
}
@@ -347,18 +363,21 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
// InvokeModel 调用模型服务,返回二进制结果
// modelKey 用于覆盖/补充模型配置 head_msg例如每次请求携带不同的 X-API-Key
func InvokeModel(ctx context.Context, model *entity.AsynchModel, body map[string]any, modelKey string) ([]byte, error) {
// 1)请求参数映射:将标准 payload 按模型配置的 requestMapping 转为模型需要的格式
func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) {
// 1) 记录模型调用次数
_ = dao.ModelGatewayLogsStat.IncRequestCount(ctx, time.Now(), model.TenantId, model.Creator, model.ModelName)
// 2请求参数映射将标准 payload 按模型配置的 requestMapping 转为模型需要的格式
//—— 请求映射实际处理为提示词构建请求,因为有附加字段及其他字段的拼接。这里不方便做请求映射
//mappedPayload := util.ReverseMap(model.RequestMapping, payload)
// 2)构建请求 URL 和超时
// 3)构建请求 URL 和超时
baseURL := strings.TrimRight(model.BaseURL, "/")
timeout := time.Duration(model.TimeoutSeconds) * time.Second
client := &http.Client{Timeout: timeout}
method := strings.ToUpper(strings.TrimSpace(model.HttpMethod))
// 3)构建 HTTP 请求
// 4)构建 HTTP 请求
var req *http.Request
switch method {
case http.MethodGet:
@@ -382,31 +401,31 @@ func InvokeModel(ctx context.Context, model *entity.AsynchModel, body map[string
req, err = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes))
}
// 4)注入请求头:先模型静态配置,再动态 modelKey后者可覆盖前者
// 5)注入请求头:先模型静态配置,再动态 modelKey后者可覆盖前者
for hk, hv := range util.ParseHeadMsgHeaders(model.HeadMsg) {
req.Header.Set(hk, hv)
}
if modelKey != "" {
req.Header.Set("Authorization", "Bearer "+modelKey)
if model.ApiKey != "" {
req.Header.Set("Authorization", "Bearer "+model.ApiKey)
}
if method != http.MethodGet {
req.Header.Set("Content-Type", "application/json")
}
// 5)发送请求
// 6)发送请求
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
// 6)读取响应体
// 7)读取响应体
b, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
// 7)检查 HTTP 状态码
// 8)检查 HTTP 状态码
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
msg := string(b)
return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
@@ -468,27 +487,15 @@ func InvokeModel(ctx context.Context, model *entity.AsynchModel, body map[string
// return mappedResponse, nil
// }
// uploadOSS 从临时文件上传 OSS
func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.AsynchTask) (*gateway.UploadFileResponse, error) {
data, err := os.ReadFile(t.TmpFile)
if err != nil {
return nil, fmt.Errorf("读取临时文件失败: %w", err)
}
_, ext := util.DetectFileType(data)
return gateway.UploadByTask(ctx, data, ext)
}
// failTask 任务失败统一处理:更新数据库 + 释放排队 + 回调
func (w *asyncWorker) failTask(ctx context.Context, t *entity.AsynchTask, startTime time.Time, errMsg string) {
func (w *asyncWorker) failTask(ctx context.Context, t *entity.ModelGatewayTask, startTime time.Time, errMsg string) {
t.State = 3
t.ErrorMsg = errMsg
t.DurationSeconds = int64(time.Since(startTime).Seconds())
_ = dao.Task.UpdateFailedGlobal(ctx, t)
_, err := dao.ModelGatewayTask.Update(ctx, t)
if err != nil {
g.Log().Warningf(ctx, "[执行任务][更新数据库失败] taskId=%s err=%v", t.TaskID, err)
}
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
}
// rollbackToPending 恢复任务状态为 PENDING
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
return dao.Task.RollbackToPendingGlobal(ctx, id)
}

View File

@@ -1,285 +1,233 @@
-- model-asynch 核心表(pgsql)
-- 1) asynch_models模型配置
-- 2) asynch_task异步任务
-- 3) logs_model_op操作日志(统计用)
-- 4) logs_model_stat按天模型请求统计(限流/监控用)
-- =========================
-- 1) asynch_models
-- model_gateway_models
-- =========================
CREATE TABLE IF NOT EXISTS asynch_models (
-- ========== 基础字段 ==========
id BIGINT PRIMARY KEY,
tenant_id BIGINT NOT NULL DEFAULT 0,
creator VARCHAR(64) NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updater VARCHAR(64) NOT NULL,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
deleted_at TIMESTAMP(6),
-- ========== 模型标识 ==========
model_name VARCHAR(128) NOT NULL,
model_type SMALLINT NOT NULL DEFAULT 0,
operator_name VARCHAR(64) NOT NULL DEFAULT '',
-- ========== 请求配置 ==========
base_url VARCHAR(256) NOT NULL,
http_method VARCHAR(8) NOT NULL DEFAULT 'POST',
head_msg JSONB NOT NULL DEFAULT '{}'::jsonb,
api_key VARCHAR(256) NOT NULL DEFAULT '',
-- ========== 状态开关 ==========
is_private SMALLINT NOT NULL DEFAULT 0,
enabled SMALLINT NOT NULL DEFAULT 1,
is_chat_model SMALLINT NOT NULL DEFAULT 0,
is_async SMALLINT NOT NULL DEFAULT 0,
is_stream SMALLINT NOT NULL DEFAULT 0,
is_owner SMALLINT NOT NULL DEFAULT 99,
-- ========== 配置相关 ==========
form_json JSONB NOT NULL DEFAULT '{}'::jsonb,
request_mapping JSONB NOT NULL DEFAULT '{}'::jsonb,
response_mapping JSONB NOT NULL DEFAULT '{}'::jsonb,
response_body JSONB NOT NULL DEFAULT '{}'::jsonb,
token_config JSONB NOT NULL DEFAULT '{}'::jsonb,
extend_mapping JSONB NOT NULL DEFAULT '{}'::jsonb,
query_config JSONB NOT NULL DEFAULT '{}'::jsonb,
stream_config JSONB NOT NULL DEFAULT '{}'::jsonb,
first_frame VARCHAR(128) NOT NULL DEFAULT '',
last_frame VARCHAR(128) NOT NULL DEFAULT '',
-- ========== 限制与重试 ==========
max_concurrency INT NOT NULL DEFAULT 10,
timeout_seconds INT NOT NULL DEFAULT 600,
retry_times SMALLINT NOT NULL DEFAULT 3,
auto_clean_seconds INT NOT NULL DEFAULT 86400,
-- ========== 其他 ==========
response_token_field VARCHAR(128) NOT NULL DEFAULT '',
CREATE TABLE IF NOT EXISTS model_gateway_models (
id int8 PRIMARY KEY,
tenant_id int8 NOT NULL DEFAULT 0,
creator varchar(64) NOT NULL,
created_at timestamp(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
updater varchar(64) NOT NULL,
updated_at timestamp(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
deleted_at timestamp(6),
model_name varchar(128) NOT NULL,
model_type int2 NOT NULL DEFAULT 0,
operator_name varchar(64) NOT NULL DEFAULT '',
base_url varchar(256) NOT NULL,
http_method varchar(8) NOT NULL DEFAULT 'POST',
head_msg jsonb NOT NULL DEFAULT '{}',
api_key varchar(256) NOT NULL DEFAULT '',
is_private int2 NOT NULL DEFAULT 0,
enabled int2 NOT NULL DEFAULT 1,
is_chat_model int2 NOT NULL DEFAULT 0,
is_owner int2 NOT NULL DEFAULT 99,
form_json jsonb NOT NULL DEFAULT '{}',
request_mapping jsonb NOT NULL DEFAULT '{}',
response_mapping jsonb NOT NULL DEFAULT '{}',
response_body varchar(128) NOT NULL DEFAULT '',
token_config jsonb NOT NULL DEFAULT '{}',
extend_mapping jsonb NOT NULL DEFAULT '{}',
query_config jsonb NOT NULL DEFAULT '{}',
stream_config jsonb NOT NULL DEFAULT '{}',
first_frame varchar(128) NOT NULL DEFAULT '',
last_frame varchar(128) NOT NULL DEFAULT '',
max_concurrency int4 NOT NULL DEFAULT 10,
timeout_seconds int4 NOT NULL DEFAULT 600,
retry_times int2 NOT NULL DEFAULT 3,
auto_clean_seconds int4 NOT NULL DEFAULT 86400,
response_token_field varchar(128) NOT NULL DEFAULT '',
call_mode int2 NOT NULL DEFAULT 0,
required_fields jsonb NOT NULL DEFAULT '[]',
max_tokens int4 DEFAULT 0
);
-- ========== 索引 ==========
CREATE UNIQUE INDEX IF NOT EXISTS uk_asynch_models_tenant_creator_chat ON asynch_models(tenant_id, creator) WHERE is_chat_model = 1 AND deleted_at IS NULL;
CREATE UNIQUE INDEX IF NOT EXISTS uk_asynch_models_tenant_model_name ON asynch_models(tenant_id, creator, model_name);
CREATE INDEX IF NOT EXISTS idx_asynch_models_tenant_id ON asynch_models(tenant_id);
CREATE INDEX IF NOT EXISTS idx_asynch_models_model_name ON asynch_models(model_name);
CREATE INDEX IF NOT EXISTS idx_asynch_models_model_type ON asynch_models(model_type);
CREATE INDEX IF NOT EXISTS idx_asynch_models_enabled ON asynch_models(enabled);
CREATE INDEX IF NOT EXISTS idx_asynch_models_deleted_at ON asynch_models(deleted_at);
CREATE UNIQUE INDEX IF NOT EXISTS uk_model_gateway_models_tenant_creator_model ON model_gateway_models (tenant_id, creator, model_name);
CREATE INDEX IF NOT EXISTS idx_model_gateway_models_model_name ON model_gateway_models (model_name);
CREATE INDEX IF NOT EXISTS idx_model_gateway_models_model_type ON model_gateway_models (model_type);
CREATE INDEX IF NOT EXISTS idx_model_gateway_models_tenant_id ON model_gateway_models (tenant_id);
CREATE INDEX IF NOT EXISTS idx_model_gateway_models_deleted_at ON model_gateway_models (deleted_at);
CREATE INDEX IF NOT EXISTS idx_model_gateway_models_enabled ON model_gateway_models (enabled);
-- ========== 注释 ==========
COMMENT ON TABLE asynch_models IS '模型配置表';
COMMENT ON TABLE model_gateway_models IS '模型配置表';
COMMENT ON COLUMN model_gateway_models.id IS '主键ID(非自增)';
COMMENT ON COLUMN model_gateway_models.tenant_id IS '租户ID';
COMMENT ON COLUMN model_gateway_models.creator IS '创建人';
COMMENT ON COLUMN model_gateway_models.created_at IS '创建时间';
COMMENT ON COLUMN model_gateway_models.updater IS '更新人';
COMMENT ON COLUMN model_gateway_models.updated_at IS '更新时间';
COMMENT ON COLUMN model_gateway_models.deleted_at IS '删除时间(软删)';
COMMENT ON COLUMN model_gateway_models.model_name IS '模型名称';
COMMENT ON COLUMN model_gateway_models.model_type IS '模型类型';
COMMENT ON COLUMN model_gateway_models.operator_name IS '运营商名称';
COMMENT ON COLUMN model_gateway_models.base_url IS '模型地址';
COMMENT ON COLUMN model_gateway_models.http_method IS '请求方式 GET/POST';
COMMENT ON COLUMN model_gateway_models.head_msg IS '请求头信息';
COMMENT ON COLUMN model_gateway_models.api_key IS '调用凭证/密钥';
COMMENT ON COLUMN asynch_models.id IS '主键ID(非自增)';
COMMENT ON COLUMN asynch_models.tenant_id IS '租户ID';
COMMENT ON COLUMN asynch_models.creator IS '创建人';
COMMENT ON COLUMN asynch_models.created_at IS '创建时间';
COMMENT ON COLUMN asynch_models.updater IS '更新人';
COMMENT ON COLUMN asynch_models.updated_at IS '更新时间';
COMMENT ON COLUMN asynch_models.deleted_at IS '删除时间(软删)';
COMMENT ON COLUMN model_gateway_models.is_private IS '是否私有化0-私有 1-公共';
COMMENT ON COLUMN model_gateway_models.enabled IS '是否启用0-停用 1-启用';
COMMENT ON COLUMN model_gateway_models.is_chat_model IS '是否为对话模型0-否 1-是';
COMMENT ON COLUMN model_gateway_models.is_owner IS '1=当前用户创建 0=超级管理员';
COMMENT ON COLUMN asynch_models.model_name IS '模型名称';
COMMENT ON COLUMN asynch_models.model_type IS '模型类型';
COMMENT ON COLUMN asynch_models.operator_name IS '运营商名称';
COMMENT ON COLUMN asynch_models.base_url IS '模型地址';
COMMENT ON COLUMN asynch_models.http_method IS '请求方式 GET/POST';
COMMENT ON COLUMN asynch_models.head_msg IS '请求头信息';
COMMENT ON COLUMN asynch_models.api_key IS '调用凭证/密钥';
COMMENT ON COLUMN asynch_models.is_private IS '是否私有化0-私有 1-公共';
COMMENT ON COLUMN asynch_models.enabled IS '是否启用0-停用 1-启用';
COMMENT ON COLUMN asynch_models.is_chat_model IS '是否为对话模型0-否 1-是';
COMMENT ON COLUMN asynch_models.is_async IS '是否异步0-同步 1-异步';
COMMENT ON COLUMN asynch_models.is_stream IS '是否流式0-非流式 1-流式';
COMMENT ON COLUMN asynch_models.is_owner IS '1=当前用户创建 0=超级管理员';
COMMENT ON COLUMN asynch_models.form_json IS '动态表单结构';
COMMENT ON COLUMN asynch_models.request_mapping IS '请求映射';
COMMENT ON COLUMN asynch_models.response_mapping IS '返回映射';
COMMENT ON COLUMN asynch_models.response_body IS '返回主体';
COMMENT ON COLUMN asynch_models.token_config IS 'Token计算配置';
COMMENT ON COLUMN asynch_models.extend_mapping IS '附加映射';
COMMENT ON COLUMN asynch_models.query_config IS '查询/回调配置';
COMMENT ON COLUMN asynch_models.stream_config IS '流式输出配置';
COMMENT ON COLUMN asynch_models.first_frame IS '首帧图片参数';
COMMENT ON COLUMN asynch_models.last_frame IS '尾帧图片参数';
COMMENT ON COLUMN asynch_models.max_concurrency IS '最大并发数';
COMMENT ON COLUMN asynch_models.timeout_seconds IS '调用模型超时(秒)';
COMMENT ON COLUMN asynch_models.retry_times IS '失败重试次数';
COMMENT ON COLUMN asynch_models.auto_clean_seconds IS '任务完成后自动清理时间(秒)';
COMMENT ON COLUMN asynch_models.response_token_field IS '响应中消耗token的字段映射';
-- =========================
-- 2) asynch_task
-- =========================
-- Asynchronous task table: one row per submitted task.
CREATE TABLE IF NOT EXISTS asynch_task (
-- Base fields
id BIGINT PRIMARY KEY, -- primary key ID (not auto-increment)
tenant_id BIGINT NOT NULL DEFAULT 0, -- tenant ID
creator VARCHAR(64) NOT NULL, -- creator
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- creation time
updater VARCHAR(64) NOT NULL, -- last updater
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- last update time
deleted_at TIMESTAMP(6), -- deletion time (soft delete)
-- Business fields
model_name VARCHAR(128) NOT NULL, -- model name
task_id VARCHAR(64) NOT NULL, -- task ID (returned to callers)
biz_name VARCHAR(128) NOT NULL DEFAULT '', -- business name (calling module/system)
callback_url VARCHAR(512) DEFAULT '', -- callback URL (optional, used for later business notification)
model_key VARCHAR(1024) DEFAULT '', -- dynamic request header (overrides/supplements the model config head_msg), e.g. X-API-Key:xxx
state SMALLINT NOT NULL DEFAULT 0, -- 0 queued / 1 running / 2 succeeded / 3 failed / 4 downloaded
oss_file VARCHAR(512) DEFAULT '', -- OSS address of the result file
file_type VARCHAR(32) DEFAULT '', -- file type (mp3/mp4/png/...)
file_size BIGINT NOT NULL DEFAULT 0, -- file size (bytes)
error_msg TEXT DEFAULT '', -- error message
started_at TIMESTAMP, -- execution start time
finished_at TIMESTAMP, -- execution end time
duration_seconds BIGINT NOT NULL DEFAULT 0, -- duration (seconds): total time from creation to completion (success/failure)
expire_at TIMESTAMP, -- written once state=4; used for cleanup
retry_count INT NOT NULL DEFAULT 0, -- number of retries so far (excluding the first attempt)
enqueue_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- enqueue time (used for queue ordering)
phase SMALLINT NOT NULL DEFAULT 0, -- 0 model phase / 1 OSS phase
tmp_file TEXT DEFAULT '', -- temporary result file path (when phase=1 only the OSS upload is retried)
input_ref TEXT DEFAULT '', -- input reference (e.g. OSS / business resource ID)
request_payload JSONB, -- request parameters (optional)
text_result TEXT DEFAULT '', -- text result (optional, supports direct callback)
epicycle_id VARCHAR(64) DEFAULT '', -- round (epicycle) ID
expend_tokens BIGINT NOT NULL DEFAULT 0 -- number of tokens consumed
);
-- Indexes for asynch_task. (tenant_id, task_id) is enforced unique; the others are
-- single-column lookup indexes.
CREATE UNIQUE INDEX IF NOT EXISTS uk_asynch_task_tenant_task_id ON asynch_task(tenant_id, task_id);
-- NOTE(review): tenant_id is already the leading column of uk_asynch_task_tenant_task_id,
-- so this index looks redundant — confirm against the query workload before dropping.
CREATE INDEX IF NOT EXISTS idx_asynch_task_tenant_id ON asynch_task(tenant_id);
CREATE INDEX IF NOT EXISTS idx_asynch_task_model_name ON asynch_task(model_name);
CREATE INDEX IF NOT EXISTS idx_asynch_task_biz_name ON asynch_task(biz_name);
-- NOTE(review): model_key is VARCHAR(1024) and, per its column comment, holds request
-- headers such as API keys — confirm this index is actually used by a query.
CREATE INDEX IF NOT EXISTS idx_asynch_task_model_key ON asynch_task(model_key);
CREATE INDEX IF NOT EXISTS idx_asynch_task_state ON asynch_task(state);
CREATE INDEX IF NOT EXISTS idx_asynch_task_enqueue_at ON asynch_task(enqueue_at);
CREATE INDEX IF NOT EXISTS idx_asynch_task_updated_at ON asynch_task(updated_at);
CREATE INDEX IF NOT EXISTS idx_asynch_task_expire_at ON asynch_task(expire_at);
CREATE INDEX IF NOT EXISTS idx_asynch_task_deleted_at ON asynch_task(deleted_at);
CREATE INDEX IF NOT EXISTS idx_asynch_task_epicycle_id ON asynch_task(epicycle_id);
CREATE INDEX IF NOT EXISTS idx_asynch_task_expend_tokens ON asynch_task(expend_tokens);
-- Table and column descriptions for asynch_task, stored in the database catalog via COMMENT ON.
COMMENT ON TABLE asynch_task IS '异步任务表';
COMMENT ON COLUMN asynch_task.id IS '主键ID(非自增)';
COMMENT ON COLUMN asynch_task.tenant_id IS '租户ID';
COMMENT ON COLUMN asynch_task.creator IS '创建人';
COMMENT ON COLUMN asynch_task.created_at IS '创建时间';
COMMENT ON COLUMN asynch_task.updater IS '更新人';
COMMENT ON COLUMN asynch_task.updated_at IS '更新时间';
COMMENT ON COLUMN asynch_task.deleted_at IS '删除时间(软删)';
COMMENT ON COLUMN asynch_task.model_name IS '模型名称';
COMMENT ON COLUMN asynch_task.task_id IS '任务ID(对外返回)';
COMMENT ON COLUMN asynch_task.biz_name IS '业务名称(调用方模块/系统)';
COMMENT ON COLUMN asynch_task.callback_url IS '回调地址(可选,用于后续业务通知)';
COMMENT ON COLUMN asynch_task.model_key IS '动态请求头(用于覆盖/补充模型配置 head_msg),如 X-API-Key:xxx';
COMMENT ON COLUMN asynch_task.state IS '0排队中/1执行中/2成功/3失败/4已下载';
COMMENT ON COLUMN asynch_task.oss_file IS '结果文件OSS地址';
COMMENT ON COLUMN asynch_task.file_type IS '文件类型(mp3/mp4/png/...)';
COMMENT ON COLUMN asynch_task.file_size IS '文件大小(字节)';
COMMENT ON COLUMN asynch_task.error_msg IS '错误信息';
COMMENT ON COLUMN asynch_task.started_at IS '开始执行时间';
COMMENT ON COLUMN asynch_task.finished_at IS '执行结束时间';
COMMENT ON COLUMN asynch_task.duration_seconds IS '耗时(秒):从创建到完成(成功/失败)整体耗时';
COMMENT ON COLUMN asynch_task.expire_at IS 'state=4 后写入,用于清理';
COMMENT ON COLUMN asynch_task.retry_count IS '已重试次数(不含首次)';
COMMENT ON COLUMN asynch_task.enqueue_at IS '入队时间(用于排队顺序)';
-- phase: the description labels both values (0 = model phase, 1 = OSS phase); the
-- previous text left out the "0", so the first value had no number attached.
COMMENT ON COLUMN asynch_task.phase IS '执行阶段 0模型阶段/1OSS阶段(模型已成功,等待上传OSS)';
COMMENT ON COLUMN asynch_task.tmp_file IS '临时结果文件路径(phase=1 时仅重试 OSS 上传)';
COMMENT ON COLUMN asynch_task.input_ref IS '输入引用(如OSS/业务资源ID等)';
COMMENT ON COLUMN asynch_task.request_payload IS '请求参数(可选,JSON)';
COMMENT ON COLUMN asynch_task.text_result IS '文本类结果(可选,支持直接回调)';
COMMENT ON COLUMN asynch_task.epicycle_id IS '轮次ID(用于标识同一轮次的任务)';
COMMENT ON COLUMN asynch_task.expend_tokens IS '消耗 token 数';
-- Column descriptions for model_gateway_models (mapping/config columns, limits, and call options).
COMMENT ON COLUMN model_gateway_models.form_json IS '动态表单结构';
COMMENT ON COLUMN model_gateway_models.request_mapping IS '请求映射';
COMMENT ON COLUMN model_gateway_models.response_mapping IS '返回映射';
COMMENT ON COLUMN model_gateway_models.response_body IS '返回主体';
COMMENT ON COLUMN model_gateway_models.token_config IS 'Token计算配置';
COMMENT ON COLUMN model_gateway_models.extend_mapping IS '附加映射';
COMMENT ON COLUMN model_gateway_models.query_config IS '查询/回调配置';
COMMENT ON COLUMN model_gateway_models.stream_config IS '流式输出配置';
COMMENT ON COLUMN model_gateway_models.first_frame IS '首帧图片参数';
COMMENT ON COLUMN model_gateway_models.last_frame IS '尾帧图片参数';
COMMENT ON COLUMN model_gateway_models.max_concurrency IS '最大并发数';
COMMENT ON COLUMN model_gateway_models.timeout_seconds IS '调用模型超时(秒)';
COMMENT ON COLUMN model_gateway_models.retry_times IS '失败重试次数';
COMMENT ON COLUMN model_gateway_models.auto_clean_seconds IS '任务完成后自动清理时间(秒)';
COMMENT ON COLUMN model_gateway_models.response_token_field IS '响应中消耗token的字段映射';
-- call_mode values per the description: 0 synchronous, 1 asynchronous, 2 streaming.
COMMENT ON COLUMN model_gateway_models.call_mode IS '调用模式0-同步 1-异步 2-流式';
COMMENT ON COLUMN model_gateway_models.required_fields IS '必选字段列表';
-- max_tokens: per the description, 0 means the value is not sent.
COMMENT ON COLUMN model_gateway_models.max_tokens IS '最大 token 数0 表示不传';
-- =========================
-- 3) logs_model_op
-- model_gateway_task
-- =========================
-- logs_model_op: operation log table (task creation etc.; used for statistics).
CREATE TABLE IF NOT EXISTS logs_model_op (
-- Base fields
id BIGINT PRIMARY KEY,
tenant_id BIGINT NOT NULL DEFAULT 0,
creator VARCHAR(64) NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updater VARCHAR(64) NOT NULL,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
deleted_at TIMESTAMP(6),
-- Basic audit information
ip VARCHAR(64) DEFAULT '',
user_agent VARCHAR(256) DEFAULT '',
api_path VARCHAR(256) DEFAULT '',
http_method VARCHAR(16) DEFAULT '',
-- Business information
biz_name VARCHAR(128) NOT NULL DEFAULT '', -- Calling business module/system
model_name VARCHAR(128) NOT NULL DEFAULT '',
task_id VARCHAR(64) NOT NULL DEFAULT '',
-- Statistics fields
op_type VARCHAR(64) NOT NULL DEFAULT 'createTask', -- Operation type (defaults to task creation)
success SMALLINT NOT NULL DEFAULT 1, -- 1 success / 0 failure
error_msg TEXT DEFAULT '',
cost_ms BIGINT NOT NULL DEFAULT 0, -- Elapsed time (milliseconds)
-- Request/response JSON (for later statistical analysis)
request_payload JSONB,
response_payload JSONB
);
-- model_gateway_task: model gateway task table.
CREATE TABLE IF NOT EXISTS model_gateway_task (
id int8 PRIMARY KEY, -- Primary key ID
tenant_id int8 NOT NULL DEFAULT 0, -- Tenant ID
creator varchar(64) NOT NULL, -- Creator
created_at timestamp(6) NOT NULL DEFAULT CURRENT_TIMESTAMP, -- Creation time
updater varchar(64) NOT NULL, -- Last updater
updated_at timestamp(6) NOT NULL DEFAULT CURRENT_TIMESTAMP, -- Last update time
deleted_at timestamp(6), -- Deletion time (soft delete)
model_name varchar(128) NOT NULL, -- Model name
task_id varchar(64) NOT NULL, -- Task ID returned to callers
biz_name varchar(128) NOT NULL DEFAULT '', -- Business name (calling module/system)
callback_url varchar(512) DEFAULT '', -- Callback URL
state int2 NOT NULL DEFAULT 0, -- 0 queued / 1 running / 2 succeeded / 3 failed / 4 downloaded
retry_count int4 NOT NULL DEFAULT 0, -- Retries performed so far
phase int2 NOT NULL DEFAULT 0, -- Execution phase: 0 model phase / 1 OSS phase
tmp_file text DEFAULT '', -- Temporary result file path
error_msg text DEFAULT '', -- Error message
result_file jsonb NOT NULL DEFAULT '{}', -- Result file: {oss_file, file_type, file_size}
request_payload jsonb NOT NULL DEFAULT '{}', -- Request parameters (JSON)
text_result jsonb NOT NULL DEFAULT '{}', -- Text result
expend_tokens int8 NOT NULL DEFAULT 0, -- Tokens consumed
duration_seconds int8 NOT NULL DEFAULT 0, -- Duration (seconds)
epicycle_id varchar(64) NOT NULL DEFAULT '' -- Round ID
);
-- Indexes for logs_model_op: one composite (tenant_id, created_at) index plus
-- single-column indexes on model_name, biz_name, task_id, op_type and deleted_at.
CREATE INDEX IF NOT EXISTS idx_logs_model_op_tenant_time ON logs_model_op(tenant_id, created_at);
CREATE INDEX IF NOT EXISTS idx_logs_model_op_model_name ON logs_model_op(model_name);
CREATE INDEX IF NOT EXISTS idx_logs_model_op_biz_name ON logs_model_op(biz_name);
CREATE INDEX IF NOT EXISTS idx_logs_model_op_task_id ON logs_model_op(task_id);
CREATE INDEX IF NOT EXISTS idx_logs_model_op_op_type ON logs_model_op(op_type);
CREATE INDEX IF NOT EXISTS idx_logs_model_op_deleted_at ON logs_model_op(deleted_at);
-- Indexes for model_gateway_task. (tenant_id, creator, task_id) is enforced unique;
-- task_id also gets its own index because it is not the leading column of the unique index.
CREATE UNIQUE INDEX IF NOT EXISTS uk_model_gateway_task_tenant_creator_task_id ON model_gateway_task (tenant_id, creator, task_id);
CREATE INDEX IF NOT EXISTS idx_model_gateway_task_task_id ON model_gateway_task (task_id);
CREATE INDEX IF NOT EXISTS idx_model_gateway_task_state ON model_gateway_task (state);
CREATE INDEX IF NOT EXISTS idx_model_gateway_task_deleted_at ON model_gateway_task (deleted_at);
-- Table and column descriptions for model_gateway_task, stored in the database catalog via COMMENT ON.
COMMENT ON TABLE model_gateway_task IS '模型网关任务表';
COMMENT ON COLUMN model_gateway_task.id IS '主键ID';
COMMENT ON COLUMN model_gateway_task.tenant_id IS '租户ID';
COMMENT ON COLUMN model_gateway_task.creator IS '创建人';
COMMENT ON COLUMN model_gateway_task.created_at IS '创建时间';
COMMENT ON COLUMN model_gateway_task.updater IS '更新人';
COMMENT ON COLUMN model_gateway_task.updated_at IS '更新时间';
COMMENT ON COLUMN model_gateway_task.deleted_at IS '删除时间(软删)';
COMMENT ON COLUMN model_gateway_task.model_name IS '模型名称';
COMMENT ON COLUMN model_gateway_task.task_id IS '任务ID对外返回';
COMMENT ON COLUMN model_gateway_task.biz_name IS '业务名称(调用方模块/系统)';
COMMENT ON COLUMN model_gateway_task.callback_url IS '回调地址';
COMMENT ON COLUMN model_gateway_task.state IS '0排队中/1执行中/2成功/3失败/4已下载';
COMMENT ON COLUMN model_gateway_task.retry_count IS '已重试次数';
COMMENT ON COLUMN model_gateway_task.phase IS '执行阶段0模型阶段/1OSS阶段';
COMMENT ON COLUMN model_gateway_task.tmp_file IS '临时结果文件路径';
COMMENT ON COLUMN model_gateway_task.error_msg IS '错误信息';
COMMENT ON COLUMN model_gateway_task.result_file IS '结果文件:{oss_file, file_type, file_size}';
COMMENT ON COLUMN model_gateway_task.request_payload IS '请求参数JSON';
COMMENT ON COLUMN model_gateway_task.text_result IS '文本类结果';
COMMENT ON COLUMN model_gateway_task.expend_tokens IS '消耗token数';
COMMENT ON COLUMN model_gateway_task.duration_seconds IS '耗时(秒)';
COMMENT ON COLUMN model_gateway_task.epicycle_id IS '轮次ID';
-- Table and column descriptions for logs_model_op, stored in the database catalog via COMMENT ON.
COMMENT ON TABLE logs_model_op IS '操作记录日志表(创建任务等,用于统计)';
COMMENT ON COLUMN logs_model_op.id IS '主键ID(非自增)';
COMMENT ON COLUMN logs_model_op.tenant_id IS '租户ID';
COMMENT ON COLUMN logs_model_op.creator IS '创建人';
COMMENT ON COLUMN logs_model_op.created_at IS '创建时间';
COMMENT ON COLUMN logs_model_op.updater IS '更新人';
COMMENT ON COLUMN logs_model_op.updated_at IS '更新时间';
COMMENT ON COLUMN logs_model_op.deleted_at IS '删除时间(软删)';
COMMENT ON COLUMN logs_model_op.ip IS '客户端IP';
COMMENT ON COLUMN logs_model_op.user_agent IS 'User-Agent';
COMMENT ON COLUMN logs_model_op.api_path IS '接口路径';
COMMENT ON COLUMN logs_model_op.http_method IS 'HTTP方法';
COMMENT ON COLUMN logs_model_op.biz_name IS '业务名称(调用方模块/系统)';
COMMENT ON COLUMN logs_model_op.model_name IS '模型名称';
COMMENT ON COLUMN logs_model_op.task_id IS '任务ID';
COMMENT ON COLUMN logs_model_op.op_type IS '操作类型(如 createTask/getTaskResult/getTaskBatch 等)';
COMMENT ON COLUMN logs_model_op.success IS '是否成功1成功/0失败';
COMMENT ON COLUMN logs_model_op.error_msg IS '错误信息(失败时)';
COMMENT ON COLUMN logs_model_op.cost_ms IS '耗时(毫秒)';
COMMENT ON COLUMN logs_model_op.request_payload IS '请求 JSON';
COMMENT ON COLUMN logs_model_op.response_payload IS '响应 JSON';
-- =========================
-- 4) logs_model_stat
-- model_gateway_log_stat
-- =========================
-- logs_model_stat: per-day model request statistics (used for rate limiting / monitoring).
-- One row per (day, tenant_id, creator, model_name), enforced by the primary key.
CREATE TABLE IF NOT EXISTS logs_model_stat (
day DATE NOT NULL, -- Day (YYYY-MM-DD)
tenant_id BIGINT NOT NULL DEFAULT 0, -- Tenant ID
creator VARCHAR(64) NOT NULL DEFAULT '', -- Creator
model_name VARCHAR(128) NOT NULL DEFAULT '', -- Model name
request_count BIGINT NOT NULL DEFAULT 0, -- Request count
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY(day, tenant_id, creator, model_name)
);
-- model_gateway_log_stat: per-day statistics table.
-- One row per (day, tenant_id, creator, model_name), enforced by the primary key.
CREATE TABLE IF NOT EXISTS model_gateway_log_stat (
day date NOT NULL, -- Day (YYYY-MM-DD)
tenant_id int8 NOT NULL DEFAULT 0, -- Tenant ID
creator varchar(64) NOT NULL DEFAULT '', -- Creator
model_name varchar(128) NOT NULL DEFAULT '', -- Model name
request_count int8 NOT NULL DEFAULT 0, -- Request count
created_at timestamp(6) NOT NULL DEFAULT CURRENT_TIMESTAMP, -- Creation time
updated_at timestamp(6) NOT NULL DEFAULT CURRENT_TIMESTAMP, -- Last update time
PRIMARY KEY (day, tenant_id, creator, model_name)
);
-- Secondary indexes to support filtering by time range / tenant / creator / model.
CREATE INDEX IF NOT EXISTS idx_logs_model_stat_tenant_day ON logs_model_stat(tenant_id, day);
-- NOTE(review): day is already the leading column of the logs_model_stat primary key,
-- so this index looks redundant — confirm against the query workload before dropping.
CREATE INDEX IF NOT EXISTS idx_logs_model_stat_day ON logs_model_stat(day);
CREATE INDEX IF NOT EXISTS idx_logs_model_stat_model_name ON logs_model_stat(model_name);
CREATE INDEX IF NOT EXISTS idx_logs_model_stat_creator ON logs_model_stat(creator);
-- NOTE(review): same observation — day leads the model_gateway_log_stat primary key.
CREATE INDEX IF NOT EXISTS idx_model_gateway_log_stat_day ON model_gateway_log_stat (day);
CREATE INDEX IF NOT EXISTS idx_model_gateway_log_stat_creator ON model_gateway_log_stat (creator);
CREATE INDEX IF NOT EXISTS idx_model_gateway_log_stat_model_name ON model_gateway_log_stat (model_name);
CREATE INDEX IF NOT EXISTS idx_model_gateway_log_stat_tenant_day ON model_gateway_log_stat (tenant_id, day);
-- Table and column descriptions for logs_model_stat, stored in the database catalog via COMMENT ON.
COMMENT ON TABLE logs_model_stat IS '按天模型请求统计(用于限流/监控)';
-- day: the description names the column ("天" = day) before the format; the previous
-- text was only the parenthesised format with no noun.
COMMENT ON COLUMN logs_model_stat.day IS '天(YYYY-MM-DD)';
COMMENT ON COLUMN logs_model_stat.tenant_id IS '租户ID';
COMMENT ON COLUMN logs_model_stat.creator IS '创建人';
COMMENT ON COLUMN logs_model_stat.model_name IS '模型名称';
COMMENT ON COLUMN logs_model_stat.request_count IS '请求次数';
COMMENT ON COLUMN logs_model_stat.created_at IS '创建时间';
COMMENT ON COLUMN logs_model_stat.updated_at IS '更新时间';
-- Table and column descriptions for model_gateway_log_stat, stored in the database catalog via COMMENT ON.
COMMENT ON TABLE model_gateway_log_stat IS '按天统计表';
COMMENT ON COLUMN model_gateway_log_stat.day IS 'YYYY-MM-DD';
COMMENT ON COLUMN model_gateway_log_stat.tenant_id IS '租户ID';
COMMENT ON COLUMN model_gateway_log_stat.creator IS '创建人';
COMMENT ON COLUMN model_gateway_log_stat.model_name IS '模型名称';
COMMENT ON COLUMN model_gateway_log_stat.request_count IS '请求次数';
COMMENT ON COLUMN model_gateway_log_stat.created_at IS '创建时间';
COMMENT ON COLUMN model_gateway_log_stat.updated_at IS '更新时间';
-- =========================
-- model_gateway_logs_op
-- =========================
-- model_gateway_logs_op: operation log table.
CREATE TABLE IF NOT EXISTS model_gateway_logs_op (
id int8 PRIMARY KEY, -- Primary key ID (not auto-increment)
tenant_id int8 NOT NULL DEFAULT 0, -- Tenant ID
creator varchar(64) NOT NULL, -- Creator
created_at timestamp(6) NOT NULL DEFAULT CURRENT_TIMESTAMP, -- Creation time
updater varchar(64) NOT NULL, -- Last updater
updated_at timestamp(6) NOT NULL DEFAULT CURRENT_TIMESTAMP, -- Last update time
deleted_at timestamp(6), -- Deletion time (soft delete)
ip varchar(64) DEFAULT '', -- Client IP
user_agent varchar(256) DEFAULT '', -- User-Agent
api_path varchar(256) DEFAULT '', -- API path
http_method varchar(16) DEFAULT '', -- HTTP method
biz_name varchar(128) NOT NULL DEFAULT '', -- Business name (calling module/system)
model_name varchar(128) NOT NULL DEFAULT '', -- Model name
task_id varchar(64) NOT NULL DEFAULT '', -- Task ID
op_type varchar(64) NOT NULL DEFAULT 'createTask', -- Operation type
success int2 NOT NULL DEFAULT 1, -- 1 success / 0 failure
error_msg text DEFAULT '', -- Error message (on failure)
cost_ms int8 NOT NULL DEFAULT 0, -- Elapsed time (milliseconds)
request_payload jsonb, -- Request JSON
response_payload jsonb -- Response JSON
);
-- Indexes for model_gateway_logs_op: single-column indexes on task_id, biz_name,
-- model_name, op_type and deleted_at, plus one composite (tenant_id, created_at) index.
CREATE INDEX IF NOT EXISTS idx_model_gateway_logs_op_task_id ON model_gateway_logs_op (task_id);
CREATE INDEX IF NOT EXISTS idx_model_gateway_logs_op_biz_name ON model_gateway_logs_op (biz_name);
CREATE INDEX IF NOT EXISTS idx_model_gateway_logs_op_model_name ON model_gateway_logs_op (model_name);
CREATE INDEX IF NOT EXISTS idx_model_gateway_logs_op_op_type ON model_gateway_logs_op (op_type);
CREATE INDEX IF NOT EXISTS idx_model_gateway_logs_op_deleted_at ON model_gateway_logs_op (deleted_at);
CREATE INDEX IF NOT EXISTS idx_model_gateway_logs_op_tenant_time ON model_gateway_logs_op (tenant_id, created_at);
-- Table and column descriptions for model_gateway_logs_op, stored in the database catalog via COMMENT ON.
COMMENT ON TABLE model_gateway_logs_op IS '操作日志表';
COMMENT ON COLUMN model_gateway_logs_op.id IS '主键ID非自增';
COMMENT ON COLUMN model_gateway_logs_op.tenant_id IS '租户ID';
COMMENT ON COLUMN model_gateway_logs_op.creator IS '创建人';
COMMENT ON COLUMN model_gateway_logs_op.created_at IS '创建时间';
COMMENT ON COLUMN model_gateway_logs_op.updater IS '更新人';
COMMENT ON COLUMN model_gateway_logs_op.updated_at IS '更新时间';
COMMENT ON COLUMN model_gateway_logs_op.deleted_at IS '删除时间(软删)';
COMMENT ON COLUMN model_gateway_logs_op.ip IS '客户端IP';
COMMENT ON COLUMN model_gateway_logs_op.user_agent IS 'User-Agent';
COMMENT ON COLUMN model_gateway_logs_op.api_path IS '接口路径';
COMMENT ON COLUMN model_gateway_logs_op.http_method IS 'HTTP方法';
COMMENT ON COLUMN model_gateway_logs_op.biz_name IS '业务名称(调用方模块/系统)';
COMMENT ON COLUMN model_gateway_logs_op.model_name IS '模型名称';
COMMENT ON COLUMN model_gateway_logs_op.task_id IS '任务ID';
COMMENT ON COLUMN model_gateway_logs_op.op_type IS '操作类型';
COMMENT ON COLUMN model_gateway_logs_op.success IS '是否成功1成功/0失败';
COMMENT ON COLUMN model_gateway_logs_op.error_msg IS '错误信息(失败时)';
COMMENT ON COLUMN model_gateway_logs_op.cost_ms IS '耗时(毫秒)';
COMMENT ON COLUMN model_gateway_logs_op.request_payload IS '请求 JSON';
COMMENT ON COLUMN model_gateway_logs_op.response_payload IS '响应 JSON';