Compare commits

2 Commits

34 changed files with 807 additions and 1214 deletions

View File

@@ -21,7 +21,7 @@ import (
) )
// ParseAndValidate 解析并校验结果 // ParseAndValidate 解析并校验结果
func ParseAndValidate(raw map[string]any, model *entity.AsynchModel) (map[string]any, error) { func ParseAndValidate(raw map[string]any, model *entity.ModelGatewayModel) (map[string]any, error) {
// 1) 解析 content 字符串为 rounds 数组 // 1) 解析 content 字符串为 rounds 数组
contentVal, ok := raw[model.ResponseBody] contentVal, ok := raw[model.ResponseBody]
if !ok { if !ok {
@@ -94,53 +94,6 @@ func ParseStructResult(raw map[string]any, responseBody string) map[string]any {
} }
} }
// ValidatePromptResult checks the structural integrity of a model response.
//
// raw must contain a "rounds" key holding a non-empty []any. When
// model.RequiredFields is empty only that structural check is performed.
// Otherwise every round that is a map[string]any must have a non-nil value
// for each required field; the first missing field is returned as an error
// naming the round index. Rounds that are not map[string]any are skipped.
func ValidatePromptResult(raw map[string]any, model *entity.AsynchModel) error {
	// Step 1: "rounds" must be present, be an array, and be non-empty.
	value, present := raw["rounds"]
	if !present {
		return fmt.Errorf("缺少 rounds 字段")
	}
	items, isArray := value.([]any)
	switch {
	case !isArray:
		return fmt.Errorf("rounds 不是数组")
	case len(items) == 0:
		return fmt.Errorf("rounds 数组为空")
	}

	// Step 2: no required fields configured, nothing further to verify.
	required := model.RequiredFields
	if len(required) == 0 {
		return nil
	}

	// Step 3: check every round object against the required field list.
	for idx, item := range items {
		obj, isObject := item.(map[string]any)
		if !isObject {
			// Non-object entries are ignored, not rejected.
			continue
		}
		for _, name := range required {
			if gjson.New(obj).Get(name).IsNil() {
				return fmt.Errorf("rounds[%d] 缺少必填字段: %s", idx, name)
			}
		}
	}
	return nil
}
// validateRequiredFields verifies that a single round object has a non-nil
// value for every entry of requiredFields.
//
// Each entry is passed to gjson's Get as the lookup pattern. prefix is
// prepended to the error message so the caller can identify which round
// failed. It returns nil when all fields are present or when requiredFields
// is empty.
func validateRequiredFields(round map[string]any, requiredFields []string, prefix string) error {
	if len(requiredFields) == 0 {
		return nil
	}
	// Build the JSON wrapper once; it does not change between field lookups.
	doc := gjson.New(round)
	for _, field := range requiredFields {
		if doc.Get(field).IsNil() {
			return fmt.Errorf("%s 缺少必填字段: %s", prefix, field)
		}
	}
	return nil
}
// ParseHeadMsgHeaders 从 head_msg JSON 中提取请求头 // ParseHeadMsgHeaders 从 head_msg JSON 中提取请求头
// head_msg 格式示例: // head_msg 格式示例:
// //
@@ -198,16 +151,17 @@ func MapResponsePayload(mapping map[string]any, result map[string]any) (map[stri
return mapped, nil return mapped, nil
} }
// GetModelBody 获取数据库中保存的模型信息 //
func GetModelBody(v map[string]any) map[string]any { //// GetModelBody 获取数据库中保存的模型信息
if v == nil { //func GetModelBody(v map[string]any) map[string]any {
return nil // if v == nil {
} // return nil
if p, ok := v["body"]; ok { // }
return gconv.Map(p) // if p, ok := v["body"]; ok {
} // return gconv.Map(p)
return v // }
} // return v
//}
// BodyToQuery 将 body 转为 url.Values // BodyToQuery 将 body 转为 url.Values
func BodyToQuery(payload map[string]any) (url.Values, error) { func BodyToQuery(payload map[string]any) (url.Values, error) {

View File

@@ -6,6 +6,14 @@ const (
CallModeStream = 2 // 流式调用 CallModeStream = 2 // 流式调用
) )
// Task state values. The numeric values are spelled out explicitly rather
// than generated so that they stay stable if entries are reordered.
const (
	// TaskStatusPending: queued, waiting to be picked up.
	TaskStatusPending = 0
	// TaskStatusRunning: currently executing.
	TaskStatusRunning = 1
	// TaskStatusSuccess: finished successfully.
	TaskStatusSuccess = 2
	// TaskStatusFailed: finished with an error.
	TaskStatusFailed = 3
	// TaskStatusDownloaded: result has been downloaded.
	TaskStatusDownloaded = 4
)
const ( const (
BuildTypePrompt = 1 //提示词构建 BuildTypePrompt = 1 //提示词构建
BuildTypeNode = 2 //节点构建 BuildTypeNode = 2 //节点构建

View File

@@ -5,8 +5,8 @@ const (
) )
const ( const (
TableNameModel = "asynch_models" // 模型表 TableNameModel = "model_gateway_models" // 模型表
TableNameTask = "asynch_task" // 任务表 TableNameTask = "model_gateway_task" // 任务表
TableNameOpLog = "logs_model_op" // 操作日志表 TableNameOpLog = "model_gateway_logs_op" // 操作日志表
TableNameStat = "logs_model_stat" // 按天统计表(请求次数) TableNameStat = "model_gateway_logs_stat" // 按天统计表
) )

View File

@@ -7,12 +7,12 @@ import (
"model-gateway/model/dto" "model-gateway/model/dto"
) )
type stat struct{} // ModelGatewayLogsStat 统计控制器
var ModelGatewayLogsStat = new(stat)
// Stat 统计控制器 type stat struct{}
var Stat = new(stat)
// ListModelStat 统计列表 // ListModelStat 统计列表
func (c *stat) ListModelStat(ctx context.Context, req *dto.ListModelStatReq) (res *dto.ListModelStatRes, err error) { func (c *stat) ListModelStat(ctx context.Context, req *dto.ListModelStatReq) (res *dto.ListModelStatRes, err error) {
return statService.Stat.List(ctx, req) return statService.ModelGatewayLogsStat.List(ctx, req)
} }

View File

@@ -7,36 +7,36 @@ import (
"model-gateway/service/queue" "model-gateway/service/queue"
) )
type model struct{} // ModelGatewayModels 模型配置控制器
var ModelGatewayModels = new(model)
// Model 模型配置控制器 type model struct{}
var Model = new(model)
// CreateModel 添加配置 // CreateModel 添加配置
func (c *model) CreateModel(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) { func (c *model) CreateModel(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
return modelService.Model.Create(ctx, req) return modelService.ModelGatewayModels.Create(ctx, req)
} }
// UpdateModel 更改配置 // UpdateModel 更改配置
func (c *model) UpdateModel(ctx context.Context, req *dto.UpdateModelReq) (res *dto.UpdateModelRes, err error) { func (c *model) UpdateModel(ctx context.Context, req *dto.UpdateModelReq) (res *dto.UpdateModelRes, err error) {
err = modelService.Model.Update(ctx, req) err = modelService.ModelGatewayModels.Update(ctx, req)
return return
} }
// DeleteModel 删除配置 // DeleteModel 删除配置
func (c *model) DeleteModel(ctx context.Context, req *dto.DeleteModelReq) (res *dto.DeleteModelRes, err error) { func (c *model) DeleteModel(ctx context.Context, req *dto.DeleteModelReq) (res *dto.DeleteModelRes, err error) {
err = modelService.Model.Delete(ctx, req) err = modelService.ModelGatewayModels.Delete(ctx, req)
return return
} }
// GetModel 获取配置详情 // GetModel 获取配置详情
func (c *model) GetModel(ctx context.Context, req *dto.GetModelReq) (res *dto.GetModelRes, err error) { func (c *model) GetModel(ctx context.Context, req *dto.GetModelReq) (res *dto.GetModelRes, err error) {
return modelService.Model.Get(ctx, req) return modelService.ModelGatewayModels.Get(ctx, req)
} }
// ListModel 配置列表 // ListModel 配置列表
func (c *model) ListModel(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) { func (c *model) ListModel(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) {
return modelService.Model.List(ctx, req) return modelService.ModelGatewayModels.List(ctx, req)
} }
// AutoTune 动态调参(由上层定时任务每小时触发一次) // AutoTune 动态调参(由上层定时任务每小时触发一次)
@@ -56,11 +56,11 @@ func (c *model) ListOperator(ctx context.Context, req *dto.ListOperatorReq) (res
// UpdateChatModel 更新是否为聊天模型 // UpdateChatModel 更新是否为聊天模型
func (c *model) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) (res *dto.UpdateChatModelRes, err error) { func (c *model) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) (res *dto.UpdateChatModelRes, err error) {
err = modelService.Model.UpdateChatModel(ctx, req) err = modelService.ModelGatewayModels.UpdateChatModel(ctx, req)
return return
} }
// GetIsChatModel 获取当前会话模型 // GetIsChatModel 获取当前会话模型
func (c *model) GetIsChatModel(ctx context.Context, req *dto.GetIsChatModelReq) (res *dto.GetIsChatModelRes, err error) { func (c *model) GetIsChatModel(ctx context.Context, req *dto.GetIsChatModelReq) (res *dto.GetIsChatModelRes, err error) {
return modelService.Model.GetIsChatModel(ctx) return modelService.ModelGatewayModels.GetIsChatModel(ctx)
} }

View File

@@ -2,48 +2,42 @@ package controller
import ( import (
"context" "context"
"model-gateway/service/job"
taskService "model-gateway/service/task" taskService "model-gateway/service/task"
"model-gateway/model/dto" "model-gateway/model/dto"
) )
type task struct{} // ModelGatewayTask 任务控制器
var ModelGatewayTask = new(task)
// Task 任务控制器 type task struct{}
var Task = new(task)
// CreateTask 根据 modelName 创建异步任务,返回 taskId // CreateTask 根据 modelName 创建异步任务,返回 taskId
func (c *task) CreateTask(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) { func (c *task) CreateTask(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
return taskService.Task.Create(ctx, req) return taskService.ModelGatewayTask.Create(ctx, req)
} }
// ModelTaskCallback 接收模型异步任务的回调通知 // GetTaskResult 获取单条任务结果(返回 *dto.GetTaskResultRes
func (c *task) ModelTaskCallback(ctx context.Context, req *dto.ModelTaskCallbackReq) (res *dto.ModelTaskCallbackRes, err error) {
return taskService.Task.ModelTaskCallback(ctx, req)
}
// QueryPendingTasks 批量轮询进行中的异步任务
func (c *task) QueryPendingTasks(ctx context.Context, req *dto.QueryPendingTasksReq) (res *dto.QueryPendingTasksRes, err error) {
return taskService.Task.QueryPendingTasks(ctx, req)
}
// GetTaskResult 获取任务结果(只返回 oss 地址 + state
func (c *task) GetTaskResult(ctx context.Context, req *dto.GetTaskResultReq) (res *dto.GetTaskResultRes, err error) { func (c *task) GetTaskResult(ctx context.Context, req *dto.GetTaskResultReq) (res *dto.GetTaskResultRes, err error) {
return taskService.Task.GetResult(ctx, req.TaskID) return taskService.ModelGatewayTask.GetResult(ctx, req.TaskID)
} }
// GetTaskBatch 批量查询任务(成功任务标记为已下载 // GetTaskBatch 批量查询任务(返回 *[]dto.GetTaskBatchItem
func (c *task) GetTaskBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) { func (c *task) GetTaskBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
return taskService.Task.GetBatch(ctx, req) return taskService.ModelGatewayTask.GetBatch(ctx, req)
} }
// ListTask 任务列表分页查询 // ListTask 任务列表分页查询
func (c *task) ListTask(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) { func (c *task) ListTask(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
return taskService.Task.List(ctx, req) return taskService.ModelGatewayTask.List(ctx, req)
} }
// CleanWork 手动触发一次 cleaner由上层定时任务调用 // ModelTaskCallback 接收模型异步任务的回调通知 —— 待调整
func (c *task) CleanWork(ctx context.Context, req *dto.CleanWorkReq) (res *dto.CleanWorkRes, err error) { func (c *task) ModelTaskCallback(ctx context.Context, req *dto.ModelTaskCallbackReq) (res *dto.ModelTaskCallbackRes, err error) {
return job.Cleaner.RunOnce(ctx) return taskService.ModelGatewayTask.ModelTaskCallback(ctx, req)
}
// QueryPendingTasks 批量轮询进行中的异步任务 —— 待调整
func (c *task) QueryPendingTasks(ctx context.Context, req *dto.QueryPendingTasksReq) (res *dto.QueryPendingTasksRes, err error) {
return taskService.ModelGatewayTask.QueryPendingTasks(ctx, req)
} }

View File

@@ -0,0 +1,23 @@
package dao
import (
"context"
"model-gateway/consts/public"
"model-gateway/model/entity"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
)
var ModelGatewayLogsOp = &modelGatewayLogsOpDao{}
type modelGatewayLogsOpDao struct{}
// Insert writes one operation-log row into the operation-log table and
// returns the value reported by the result's LastInsertId.
func (d *modelGatewayLogsOpDao) Insert(ctx context.Context, req *entity.ModelGatewayLogsOp) (int64, error) {
	table := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameOpLog)
	result, err := table.Insert(req)
	if err != nil {
		return 0, err
	}
	return result.LastInsertId()
}

View File

@@ -0,0 +1,52 @@
package dao
import (
	"context"
	"fmt"
	"time"

	"model-gateway/consts/public"
	"model-gateway/model/entity"

	"gitea.redpowerfuture.com/red-future/common/db/gfdb"
	"github.com/gogf/gf/v2/os/gtime"
	"github.com/gogf/gf/v2/util/gconv"
)
var ModelGatewayLogsStat = &modelGatewayLogsStatDao{}
type modelGatewayLogsStatDao struct{}
// IncRequestCount atomically adds one to the request counter of the
// (day, tenant, creator, model) row, creating the row with a count of 1 when
// it does not exist yet.
//
// The upsert is written as an explicit statement. Passing several strings to
// the ORM's OnDuplicate names columns to overwrite, so "request_count+1" was
// treated as a column name rather than an increment expression. The counter
// is table-qualified to avoid an ambiguous column reference in the
// DO UPDATE clause.
//
// NOTE(review): assumes PostgreSQL and a unique constraint on
// (day, tenant_id, creator, model_name), plus created_at/updated_at columns —
// confirm against the table DDL.
func (d *modelGatewayLogsStatDao) IncRequestCount(ctx context.Context, day time.Time, tenantId uint64, creator, modelName string) error {
	query := fmt.Sprintf(`
INSERT INTO %[1]s(day, tenant_id, creator, model_name, request_count, created_at, updated_at)
VALUES(?, ?, ?, ?, 1, NOW(), NOW())
ON CONFLICT (day, tenant_id, creator, model_name)
DO UPDATE SET request_count = %[1]s.request_count + 1, updated_at = NOW()`,
		public.TableNameStat,
	)
	// Only the date part identifies the statistics bucket.
	dayKey := gtime.New(day).Format("Y-m-d")
	_, err := gfdb.DB(ctx, public.DbNameModelGateway).Exec(ctx, query, dayKey, tenantId, creator, modelName)
	return err
}
// List returns statistics rows ordered by day and then request count, both
// descending, together with the total number of matching rows.
//
// req supplies the optional filters: Creator (exact match) and ModelName
// (substring match). A nil req means "no filters". Paging is applied only
// when both pageNum and pageSize are positive.
func (d *modelGatewayLogsStatDao) List(ctx context.Context, pageNum, pageSize int, req *entity.ModelGatewayLogsStat) (list []*entity.ModelGatewayLogsStat, total int64, err error) {
	if req == nil {
		req = &entity.ModelGatewayLogsStat{}
	}
	model := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameStat).
		OmitEmpty().
		Where(entity.ModelGatewayLogsStatCols.Creator, req.Creator)
	// "%"+""+"%" is the non-empty string "%%", so the pattern would never be
	// omitted and LIKE '%%' would be applied on every call. Add it only when
	// a name filter was actually given.
	if req.ModelName != "" {
		model = model.WhereLike(entity.ModelGatewayLogsStatCols.ModelName, "%"+req.ModelName+"%")
	}
	model = model.
		OrderDesc(entity.ModelGatewayLogsStatCols.Day).
		OrderDesc(entity.ModelGatewayLogsStatCols.RequestCount)
	if pageNum > 0 && pageSize > 0 {
		model = model.Page(pageNum, pageSize)
	}
	r, totalInt, err := model.AllAndCount(false)
	if err != nil {
		return nil, 0, err
	}
	total = gconv.Int64(totalInt)
	err = r.Structs(&list)
	return list, total, err
}

View File

@@ -9,71 +9,64 @@ import (
"gitea.redpowerfuture.com/red-future/common/db/gfdb" "gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
) )
var Model = &modelDao{} var ModelGatewayModels = &modelGatewayModelsDao{}
type modelDao struct{} type modelGatewayModelsDao struct{}
// Insert 插入 // Insert 插入
func (d *modelDao) Insert(ctx context.Context, req *entity.AsynchModel) (id int64, err error) { func (d *modelGatewayModelsDao) Insert(ctx context.Context, req *entity.ModelGatewayModel) (int64, error) {
m := new(entity.AsynchModel) r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).Insert(req)
err = gconv.Struct(req, &m)
if err != nil { if err != nil {
return return 0, err
}
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
Insert(m)
if err != nil {
return
} }
return r.LastInsertId() return r.LastInsertId()
} }
// Update 更新 // Update 更新
func (d *modelDao) Update(ctx context.Context, req *entity.AsynchModel) (rows int64, err error) { func (d *modelGatewayModelsDao) Update(ctx context.Context, req *entity.ModelGatewayModel) (int64, error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel). r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
OmitEmpty(). OmitEmpty().
Data(&req). Data(req).
Where(entity.AsynchModelCol.Id, req.Id). Where(entity.ModelGatewayModelCol.Id, req.Id).
Update() Update()
if err != nil { if err != nil {
return return 0, err
} }
return r.RowsAffected() return r.RowsAffected()
} }
// Delete 删除 // Delete 删除
func (d *modelDao) Delete(ctx context.Context, req *entity.AsynchModel) (rows int64, err error) { func (d *modelGatewayModelsDao) Delete(ctx context.Context, req *entity.ModelGatewayModel) (int64, error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel). r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
OmitEmpty(). OmitEmpty().
Where(entity.AsynchModelCol.Id, req.Id). Where(entity.ModelGatewayModelCol.Id, req.Id).
Delete() Delete()
if err != nil { if err != nil {
return return 0, err
} }
return r.RowsAffected() return r.RowsAffected()
} }
// Get 获取模型 // Get 获取模型
func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) { func (d *modelGatewayModelsDao) Get(ctx context.Context, req *entity.ModelGatewayModel, fields ...string) (*entity.ModelGatewayModel, error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel). r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
OmitEmpty(). OmitEmpty().
Where(entity.AsynchModelCol.Id, req.Id). Where(entity.ModelGatewayModelCol.Id, req.Id).
Where(entity.AsynchModelCol.Creator, req.Creator). Where(entity.ModelGatewayModelCol.Creator, req.Creator).
Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel). Where(entity.ModelGatewayModelCol.ModelName, req.ModelName).
Where(entity.AsynchModelCol.ModelName, req.ModelName).
Fields(fields).One() Fields(fields).One()
if err != nil { if err != nil {
return return nil, err
} }
var m entity.ModelGatewayModel
err = r.Struct(&m) err = r.Struct(&m)
return return &m, err
} }
//// Get 按ID获取带租户隔离只查当前租户 //// Get 按ID获取带租户隔离只查当前租户
//func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) { //func (d *modelGatewayModelsDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
// var whereCondition strings.Builder // var whereCondition strings.Builder
// var queryParams []interface{} // var queryParams []interface{}
// if !g.IsEmpty(req.Id) { // if !g.IsEmpty(req.Id) {
@@ -108,25 +101,25 @@ func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...s
// return // return
//} //}
// GetByAcrossTenant 按ID获取跨租户查所有租户 // GetByAcrossTenant 跨租户查询
func (d *modelDao) GetByAcrossTenant(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) { func (d *modelGatewayModelsDao) GetByAcrossTenant(ctx context.Context, req *entity.ModelGatewayModel, fields ...string) (*entity.ModelGatewayModel, error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel). r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
NoTenantId(ctx). NoTenantId(ctx).
OmitEmpty(). OmitEmpty().
Where(entity.AsynchModelCol.Id, req.Id). Where(entity.ModelGatewayModelCol.Id, req.Id).
Where(entity.AsynchModelCol.Creator, req.Creator). Where(entity.ModelGatewayModelCol.Creator, req.Creator).
Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel). Where(entity.ModelGatewayModelCol.ModelName, req.ModelName).
Where(entity.AsynchModelCol.ModelName, req.ModelName).
Fields(fields).One() Fields(fields).One()
if err != nil { if err != nil {
return return nil, err
} }
var m entity.ModelGatewayModel
err = r.Struct(&m) err = r.Struct(&m)
return return &m, err
} }
// GetByCreatorAndPlatform 按创建者、平台获取 // GetByCreatorAndPlatform 按创建者、平台获取
func (d *modelDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.AsynchModel, total int, err error) { func (d *modelGatewayModelsDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.ModelGatewayModel, total int, err error) {
sql := ` sql := `
SELECT DISTINCT ON (model_name) * SELECT DISTINCT ON (model_name) *
FROM asynch_models FROM asynch_models
@@ -186,7 +179,7 @@ WHERE deleted_at IS NULL
} }
// GetByModelNameForTenant 后台任务使用:按 tenant_id + model_name 查询,不依赖 gfdb Hook/Trace/用户上下文 // GetByModelNameForTenant 后台任务使用:按 tenant_id + model_name 查询,不依赖 gfdb Hook/Trace/用户上下文
func (d *modelDao) GetByModelNameForTenant(ctx context.Context, tenantId uint64, modelName string) (m *entity.AsynchModel, err error) { func (d *modelGatewayModelsDao) GetByModelNameForTenant(ctx context.Context, tenantId uint64, modelName string) (*entity.ModelGatewayModel, error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx,
"SELECT * FROM "+public.TableNameModel+" WHERE tenant_id=? AND model_name=? AND deleted_at IS NULL LIMIT 1", "SELECT * FROM "+public.TableNameModel+" WHERE tenant_id=? AND model_name=? AND deleted_at IS NULL LIMIT 1",
tenantId, modelName, tenantId, modelName,
@@ -197,7 +190,7 @@ func (d *modelDao) GetByModelNameForTenant(ctx context.Context, tenantId uint64,
if r.IsEmpty() { if r.IsEmpty() {
return nil, nil return nil, nil
} }
var list []*entity.AsynchModel var list []*entity.ModelGatewayModel
if err := r.Structs(&list); err != nil { if err := r.Structs(&list); err != nil {
return nil, err return nil, err
} }

View File

@@ -0,0 +1,159 @@
package dao
import (
"context"
"fmt"
"model-gateway/consts/public"
"model-gateway/model/entity"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/util/gconv"
)
var ModelGatewayTask = &modelGatewayTaskDao{}
type modelGatewayTaskDao struct{}
// Insert stores req as a new task row and returns the value reported by the
// result's LastInsertId.
//
// req is handed to the ORM directly, as the other DAOs in this package do;
// copying it into a second struct of the same type first added nothing.
func (d *modelGatewayTaskDao) Insert(ctx context.Context, req *entity.ModelGatewayTask) (id int64, err error) {
	r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).Insert(req)
	if err != nil {
		return 0, err
	}
	return r.LastInsertId()
}
// Update writes the fields of req to the task row whose primary key equals
// req.Id and returns the number of affected rows. OmitEmpty is enabled on the
// statement.
func (d *modelGatewayTaskDao) Update(ctx context.Context, req *entity.ModelGatewayTask) (rows int64, err error) {
	stmt := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
		OmitEmpty().
		Data(req).
		Where(entity.ModelGatewayTaskCol.Id, req.Id)
	result, err := stmt.Update()
	if err != nil {
		return 0, err
	}
	return result.RowsAffected()
}
// Get loads a single task, filtering on req.TaskID and req.Id. OmitEmpty is
// enabled on the query.
// NOTE(review): with both filters empty the query presumably has no WHERE
// clause and returns an arbitrary row — confirm callers always set one.
func (d *modelGatewayTaskDao) Get(ctx context.Context, req *entity.ModelGatewayTask) (m *entity.ModelGatewayTask, err error) {
	record, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
		OmitEmpty().
		Where(entity.ModelGatewayTaskCol.TaskID, req.TaskID).
		Where(entity.ModelGatewayTaskCol.Id, req.Id).
		One()
	if err != nil {
		return nil, err
	}
	// m is still a nil pointer here; Struct receives its address.
	err = record.Struct(&m)
	return m, err
}
// List returns tasks ordered by creation time, newest first, together with
// the total number of matching rows.
//
// req supplies the optional filters: Creator, BizName, State and TaskID
// (exact match) and ModelName (substring match). A nil req means "no
// filters". Paging is applied only when both pageNum and pageSize are
// positive.
//
// NOTE(review): OmitEmpty is enabled, so a State of 0 is presumably dropped
// from the conditions and cannot be used as a filter — confirm.
func (d *modelGatewayTaskDao) List(ctx context.Context, pageNum, pageSize int, req *entity.ModelGatewayTask) (list []*entity.ModelGatewayTask, total int64, err error) {
	if req == nil {
		req = &entity.ModelGatewayTask{}
	}
	model := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
		OmitEmpty().
		Where(entity.ModelGatewayTaskCol.Creator, req.Creator)
	// The name filter is a substring match, so it needs LIKE rather than "=".
	// It is guarded explicitly because "%"+""+"%" is the non-empty string
	// "%%", which would otherwise be applied on every call.
	if req.ModelName != "" {
		model = model.WhereLike(entity.ModelGatewayTaskCol.ModelName, "%"+req.ModelName+"%")
	}
	model = model.
		Where(entity.ModelGatewayTaskCol.BizName, req.BizName).
		Where(entity.ModelGatewayTaskCol.State, req.State).
		Where(entity.ModelGatewayTaskCol.TaskID, req.TaskID).
		OrderDesc(entity.ModelGatewayTaskCol.CreatedAt)
	if pageNum > 0 && pageSize > 0 {
		model = model.Page(pageNum, pageSize)
	}
	r, totalInt, err := model.AllAndCount(false)
	if err != nil {
		return nil, 0, err
	}
	total = gconv.Int64(totalInt)
	err = r.Structs(&list)
	return list, total, err
}
// Delete removes the task row whose primary key equals req.Id and returns
// the number of affected rows.
// NOTE(review): the original comment calls this a soft delete; whether the
// ORM rewrites it into a deleted_at update is not visible here — confirm.
func (d *modelGatewayTaskDao) Delete(ctx context.Context, req *entity.ModelGatewayTask) (rows int64, err error) {
	stmt := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
		Where(entity.ModelGatewayTaskCol.Id, req.Id)
	result, err := stmt.Delete()
	if err != nil {
		return 0, err
	}
	return result.RowsAffected()
}
// ListByTaskIDs loads every task whose task ID is contained in taskIDs.
// An empty taskIDs slice yields (nil, nil) without touching the database.
func (d *modelGatewayTaskDao) ListByTaskIDs(ctx context.Context, taskIDs []string) (list []*entity.ModelGatewayTask, err error) {
	if len(taskIDs) == 0 {
		return nil, nil
	}
	query := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
		WhereIn(entity.ModelGatewayTaskCol.TaskID, taskIDs)
	rows, err := query.All()
	if err != nil {
		return nil, err
	}
	err = rows.Structs(&list)
	return list, err
}
// MarkDownloadedByID moves the task with the given primary key from the
// "success" state to the "downloaded" state. The state condition in the WHERE
// clause means rows in any other state are left untouched.
func (d *modelGatewayTaskDao) MarkDownloadedByID(ctx context.Context, id int64) error {
	// Named status constants replace the former literals 2 and 4.
	_, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
		Where(entity.ModelGatewayTaskCol.Id, id).
		Where(entity.ModelGatewayTaskCol.State, public.TaskStatusSuccess).
		Data(map[string]any{entity.ModelGatewayTaskCol.State: public.TaskStatusDownloaded}).
		Update()
	return err
}
// GetPendingAsyncTasks returns up to limit tasks that are currently in the
// "running" state.
func (d *modelGatewayTaskDao) GetPendingAsyncTasks(ctx context.Context, limit int) ([]*entity.ModelGatewayTask, error) {
	var tasks []*entity.ModelGatewayTask
	// The named status constant replaces the former literal 1.
	err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
		Where(entity.ModelGatewayTaskCol.State, public.TaskStatusRunning).
		Limit(limit).
		Scan(&tasks)
	return tasks, err
}
// ======================== 事务抢占 ========================
// ClaimByID claims the pending task with the given primary key inside a
// transaction: it selects the row with LockUpdate while requiring the pending
// state, then sets its state to running.
//
// It returns an error when no pending row with that id exists. The returned
// struct is populated from the row as read before the update, so its State
// field still holds the pending value.
func (d *modelGatewayTaskDao) ClaimByID(ctx context.Context, id int64) (*entity.ModelGatewayTask, error) {
	var claimed entity.ModelGatewayTask
	claim := func(ctx context.Context, tx gdb.TX) error {
		// Lock the candidate row; the state condition filters out tasks that
		// are no longer pending.
		row, err := tx.Model(public.TableNameTask).
			Where(entity.ModelGatewayTaskCol.Id, id).
			Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending).
			Limit(1).
			LockUpdate().
			One()
		if err != nil {
			return err
		}
		if row.IsEmpty() {
			return fmt.Errorf("任务已被抢占或不存在: id=%d", id)
		}
		if err = row.Struct(&claimed); err != nil {
			return err
		}
		// Flip the state to running within the same transaction.
		_, err = tx.Model(public.TableNameTask).
			Data(&entity.ModelGatewayTask{State: public.TaskStatusRunning}).
			Where(entity.ModelGatewayTaskCol.Id, id).
			OmitEmpty().
			Update()
		return err
	}
	if err := gfdb.DB(ctx, public.DbNameModelGateway).Transaction(ctx, claim); err != nil {
		return nil, err
	}
	return &claimed, nil
}

View File

@@ -1,30 +0,0 @@
package dao
import (
"context"
"model-gateway/consts/public"
"model-gateway/model/entity"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/util/gconv"
)
type opLogDao struct{}
var OpLog = &opLogDao{}
// Insert copies req into a fresh LogsModelOp, stores it in the operation-log
// table and returns the value reported by the result's LastInsertId.
func (d *opLogDao) Insert(ctx context.Context, req *entity.LogsModelOp) (id int64, err error) {
	row := new(entity.LogsModelOp)
	if err = gconv.Struct(req, &row); err != nil {
		return 0, err
	}
	table := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameOpLog)
	result, err := table.Insert(row)
	if err != nil {
		return 0, err
	}
	return result.LastInsertId()
}

View File

@@ -1,60 +0,0 @@
package dao
import (
"context"
"fmt"
"time"
"model-gateway/consts/public"
"model-gateway/model/entity"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/os/gtime"
)
type statDao struct{}
var Stat = &statDao{}
// IncRequestCount adds one to the request counter of the
// (day, tenant, creator, model) row with a single INSERT ... ON CONFLICT
// statement, inserting the row with a count of 1 when it is missing.
// The day is reduced to its date part before being bound.
func (d *statDao) IncRequestCount(ctx context.Context, day time.Time, tenantId int64, creator, modelName string) error {
	const upsert = `
INSERT INTO %s(day, tenant_id, creator, model_name, request_count, created_at, updated_at)
VALUES(?, ?, ?, ?, 1, NOW(), NOW())
ON CONFLICT (day, tenant_id, creator, model_name)
DO UPDATE SET request_count = %s.request_count + 1, updated_at = NOW()`
	stmt := fmt.Sprintf(upsert, public.TableNameStat, public.TableNameStat)
	dayKey := gtime.New(day).Format("Y-m-d")
	_, err := gfdb.DB(ctx, public.DbNameModelGateway).Exec(ctx, stmt, dayKey, tenantId, creator, modelName)
	return err
}
// List returns statistics rows ordered by day and then request count, both
// descending, together with the total number of matching rows.
//
// Every filter is optional: startDay/endDay bound the day column
// inclusively, tenantId (when non-nil) is an exact match, creator and
// modelName are substring matches. Paging is applied only when both pageNum
// and pageSize are positive.
func (d *statDao) List(ctx context.Context, pageNum, pageSize int, startDay, endDay string, tenantId *int64, creator, modelName string) (list []*entity.LogsModelStat, total int64, err error) {
	// Read from the same named database that IncRequestCount writes to,
	// instead of whatever gfdb.DB(ctx) resolves to by default.
	m := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameStat).Where("1=1")
	if startDay != "" {
		m = m.Where("day >= ?", startDay)
	}
	if endDay != "" {
		m = m.Where("day <= ?", endDay)
	}
	if tenantId != nil {
		m = m.Where("tenant_id = ?", *tenantId)
	}
	if creator != "" {
		m = m.WhereLike("creator", "%"+creator+"%")
	}
	if modelName != "" {
		m = m.WhereLike("model_name", "%"+modelName+"%")
	}
	m = m.OrderDesc("day").OrderDesc("request_count")
	if pageNum > 0 && pageSize > 0 {
		m = m.Page(pageNum, pageSize)
	}
	r, totalInt, err := m.AllAndCount(false)
	if err != nil {
		return nil, 0, err
	}
	total = int64(totalInt)
	err = r.Structs(&list)
	return list, total, err
}

View File

@@ -1,124 +0,0 @@
package dao
import (
"context"
"model-gateway/consts/public"
"model-gateway/model/entity"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/util/gconv"
)
var Task = &taskDao{}
type taskDao struct{}
// Insert copies req into a fresh AsynchTask, stores it in the task table and
// returns the value reported by the result's LastInsertId.
func (d *taskDao) Insert(ctx context.Context, req *entity.AsynchTask) (id int64, err error) {
	m := new(entity.AsynchTask)
	if err = gconv.Struct(req, &m); err != nil {
		return 0, err
	}
	// Use the named database like Update, Get and the other methods of this
	// DAO, rather than whatever gfdb.DB(ctx) resolves to by default.
	r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
		Insert(m)
	if err != nil {
		return 0, err
	}
	return r.LastInsertId()
}
// Update writes the fields of req to the task row whose primary key equals
// req.Id and returns the number of affected rows. OmitEmpty is enabled on the
// statement.
func (d *taskDao) Update(ctx context.Context, req *entity.AsynchTask) (rows int64, err error) {
	// req is already a pointer; pass it as is instead of a pointer to it.
	r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
		OmitEmpty().
		Data(req).
		Where(entity.AsynchTaskCol.Id, req.Id).
		Update()
	if err != nil {
		return 0, err
	}
	return r.RowsAffected()
}
// Get loads a single task filtered by req.TaskID, optionally restricting the
// selected columns to fields. OmitEmpty is enabled on the query.
func (d *taskDao) Get(ctx context.Context, req *entity.AsynchTask, fields ...string) (m *entity.AsynchTask, err error) {
	query := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
		OmitEmpty().
		Where(entity.AsynchTaskCol.TaskID, req.TaskID).
		Fields(fields)
	record, err := query.One()
	if err != nil {
		return nil, err
	}
	// m is still a nil pointer here; Struct receives its address.
	err = record.Struct(&m)
	return m, err
}
// ListByTaskIDs loads every task whose task ID is contained in taskIDs.
// An empty taskIDs slice yields (nil, nil) without touching the database.
// Per the original author's note, gfdb's tenant hook applies here, so only
// rows of the current tenant are returned.
func (d *taskDao) ListByTaskIDs(ctx context.Context, taskIDs []string) (m []*entity.AsynchTask, err error) {
	if len(taskIDs) == 0 {
		return nil, nil
	}
	query := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
		OmitEmpty().
		WhereIn(entity.AsynchTaskCol.TaskID, taskIDs)
	rows, err := query.All()
	if err != nil {
		return nil, err
	}
	err = rows.Structs(&m)
	return m, err
}
// MarkDownloadedByID moves the task with the given primary key from state 2
// to state 4, stores expireAt as its expiry time and sets the updater column
// to the empty string. Rows whose state is not 2 are left untouched.
func (d *taskDao) MarkDownloadedByID(ctx context.Context, id int64, expireAt *gtime.Time) error {
	changes := gdb.Map{}
	changes[entity.AsynchTaskCol.State] = 4
	changes[entity.AsynchTaskCol.ExpireAt] = expireAt
	changes[entity.AsynchTaskCol.Updater] = ""

	stmt := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
		Where(entity.AsynchTaskCol.Id, id).
		Where(entity.AsynchTaskCol.State, 2).
		Data(changes)
	_, err := stmt.Update()
	return err
}
// List returns non-deleted tasks ordered by creation time, newest first,
// together with the total number of matching rows.
//
// modelNameLike and taskIDLike are optional substring filters, state is an
// optional exact filter (nil disables it). Paging is applied only when both
// pageNum and pageSize are positive. Per the original author's note, gfdb's
// tenant hook applies to this query.
func (d *taskDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike, taskIDLike string, state *int) (list []*entity.AsynchTask, total int64, err error) {
	query := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).Where("deleted_at IS NULL")
	if modelNameLike != "" {
		query = query.WhereLike(entity.AsynchTaskCol.ModelName, "%"+modelNameLike+"%")
	}
	if taskIDLike != "" {
		query = query.WhereLike(entity.AsynchTaskCol.TaskID, "%"+taskIDLike+"%")
	}
	if state != nil {
		query = query.Where(entity.AsynchTaskCol.State, *state)
	}
	query = query.OrderDesc(entity.AsynchTaskCol.CreatedAt)
	if pageNum > 0 && pageSize > 0 {
		query = query.Page(pageNum, pageSize)
	}

	rows, count, err := query.AllAndCount(false)
	if err != nil {
		return nil, 0, err
	}
	total = gconv.Int64(count)
	err = rows.Structs(&list)
	return list, total, err
}
// GetPendingAsyncTasks returns up to limit non-deleted tasks whose state
// column equals 1.
func (d *taskDao) GetPendingAsyncTasks(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
	query := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
		Where("state", 1).
		Where("deleted_at IS NULL").
		Limit(limit)

	var found []*entity.AsynchTask
	err := query.Scan(&found)
	return found, err
}

View File

@@ -1,214 +0,0 @@
package dao
import (
"context"
"fmt"
"time"
"model-gateway/consts/public"
"model-gateway/model/entity"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/os/gtime"
)
// ======================== 查询辅助 ========================
// taskColumns is the shared column list selected by the task-claiming
// queries in this file.
const taskColumns = `id, tenant_id, creator, model_name, task_id, biz_name, callback_url, model_key, retry_count, input_ref, request_payload, phase, tmp_file`
// ======================== 事务抢占 ========================
// claimTasks claims at most one pending (state = 0), non-deleted task inside
// a transaction and sets it to state 1 with started_at/updated_at stamped.
//
// where is appended verbatim to the SELECT's WHERE clause and args are bound
// to its placeholders. The row is selected with FOR UPDATE SKIP LOCKED. The
// result is nil when nothing matched, otherwise a one-element slice holding
// the row as read before the update.
func claimTasks(ctx context.Context, where string, args ...any) ([]*entity.AsynchTask, error) {
	var claimed []*entity.AsynchTask
	txErr := gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
		selectSQL := fmt.Sprintf(`SELECT %s FROM %s WHERE deleted_at IS NULL AND state = 0 %s LIMIT 1 FOR UPDATE SKIP LOCKED`, taskColumns, public.TableNameTask, where)
		row, err := tx.GetOne(selectSQL, args...)
		switch {
		case err != nil:
			return err
		case row.IsEmpty():
			// Nothing to claim.
			return nil
		}

		var picked entity.AsynchTask
		if err = row.Struct(&picked); err != nil {
			return err
		}

		stamp := time.Now()
		updateSQL := fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask)
		if _, err = tx.Exec(updateSQL, stamp, stamp, picked.Id); err != nil {
			return err
		}
		claimed = []*entity.AsynchTask{&picked}
		return nil
	})
	return claimed, txErr
}
// ClaimPendingGlobal claims up to batchSize pending (state = 0), non-deleted
// tasks, oldest enqueue_at first, inside one transaction and sets each to
// state 1 with started_at/updated_at stamped.
//
// A batchSize of zero or less is treated as 1. Rows are selected with
// FOR UPDATE SKIP LOCKED. The returned tasks hold the rows as read before
// the update. On any error the result is nil, so callers never see tasks
// whose claim did not complete.
func (d *taskDao) ClaimPendingGlobal(ctx context.Context, batchSize int) ([]*entity.AsynchTask, error) {
	if batchSize <= 0 {
		batchSize = 1
	}
	var tasks []*entity.AsynchTask
	err := gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
		sql := fmt.Sprintf(`SELECT %s FROM %s WHERE deleted_at IS NULL AND state = 0 ORDER BY enqueue_at ASC LIMIT %d FOR UPDATE SKIP LOCKED`, taskColumns, public.TableNameTask, batchSize)
		r, err := tx.GetAll(sql)
		if err != nil {
			return err
		}
		if r.IsEmpty() {
			return nil
		}
		if err := r.Structs(&tasks); err != nil {
			return err
		}
		now := time.Now()
		// The statement text is the same for every task; build it once.
		update := fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask)
		for _, t := range tasks {
			if _, err := tx.Exec(update, now, now, t.Id); err != nil {
				return err
			}
		}
		return nil
	})
	if err != nil {
		// tasks may already be populated by Structs; do not hand it out.
		return nil, err
	}
	return tasks, nil
}
// ClaimPendingByTaskIDGlobal claims the pending task with the given task_id
// through claimTasks. It returns (nil, nil) when taskID is empty or when no
// matching pending task was found.
func (d *taskDao) ClaimPendingByTaskIDGlobal(ctx context.Context, taskID string) (*entity.AsynchTask, error) {
	if taskID == "" {
		return nil, nil
	}
	claimed, err := claimTasks(ctx, "AND task_id = ?", taskID)
	if err != nil {
		return nil, err
	}
	if len(claimed) == 0 {
		return nil, nil
	}
	return claimed[0], nil
}
// ======================== 更新辅助 ========================
// execSQL runs a raw SQL statement with bound arguments on the context's
// database handle and reports only the error; the execution result is
// discarded.
func execSQL(ctx context.Context, sql string, args ...any) error {
	if _, err := gfdb.DB(ctx).Exec(ctx, sql, args...); err != nil {
		return err
	}
	return nil
}
// updateTask updates the task row with the given id from the fields of data.
//
// The model is built with OmitEmpty(), so zero-valued fields of data ("", 0,
// nil, empty map) are skipped rather than written. This helper therefore
// cannot be used to reset a column to its zero value.
func updateTask(ctx context.Context, id int64, data entity.AsynchTask) error {
	model := gfdb.DB(ctx).Model(ctx, public.TableNameTask).OmitEmpty()
	if _, err := model.Where(entity.AsynchTaskCol.Id, id).Data(data).Update(); err != nil {
		return err
	}
	return nil
}
// UpdateSuccessGlobal marks task t as succeeded (state = 2), stores its result
// fields, stamps finished_at with the current time and resets error_msg, phase
// and tmp_file.
//
// The update is built from an explicit column map instead of going through
// updateTask: that helper uses OmitEmpty(), which drops zero values, so the
// intended resets (error_msg = "", phase = 0, tmp_file = "") were never
// written and a task that succeeded on retry kept its stale error message.
//
// Result columns (oss_file, file_type, text_result, file_size, expend_tokens,
// duration_seconds) are still only written when t carries a non-empty value,
// matching the previous behaviour for those fields.
func (d *taskDao) UpdateSuccessGlobal(ctx context.Context, t *entity.AsynchTask) error {
	col := entity.AsynchTaskCol

	// Columns that must always be written, including zero-valued resets.
	data := map[string]any{
		col.State:      2,
		col.ErrorMsg:   "",
		col.FinishedAt: gtime.Now(),
		col.Phase:      0,
		col.TmpFile:    "",
	}

	// Optional result columns: skipped when empty, as before.
	if t.OssFile != "" {
		data[col.OssFile] = t.OssFile
	}
	if t.FileType != "" {
		data[col.FileType] = t.FileType
	}
	if len(t.TextResult) > 0 {
		data[col.TextResult] = t.TextResult
	}
	if t.FileSize != 0 {
		data[col.FileSize] = t.FileSize
	}
	if t.ExpendTokens != 0 {
		data[col.ExpendTokens] = t.ExpendTokens
	}
	if t.DurationSeconds != 0 {
		data[col.DurationSeconds] = t.DurationSeconds
	}

	_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
		Where(col.Id, t.Id).
		Data(data).
		Update()
	return err
}
// UpdateFailedGlobal marks task t as failed (state = 3), stamps finished_at
// with the current time and resets phase and tmp_file.
//
// The update is built from an explicit column map instead of going through
// updateTask: that helper uses OmitEmpty(), which drops zero values, so the
// intended resets (phase = 0, tmp_file = "") were never written.
//
// error_msg, text_result and duration_seconds are still only written when t
// carries a non-empty value, matching the previous behaviour for those fields.
//
// NOTE(review): tmp_file is now really cleared here; confirm no caller relies
// on the old value surviving a failure to locate the temporary file later.
func (d *taskDao) UpdateFailedGlobal(ctx context.Context, t *entity.AsynchTask) error {
	col := entity.AsynchTaskCol

	// Columns that must always be written, including zero-valued resets.
	data := map[string]any{
		col.State:      3,
		col.FinishedAt: gtime.Now(),
		col.Phase:      0,
		col.TmpFile:    "",
	}

	// Optional columns: skipped when empty, as before.
	if t.ErrorMsg != "" {
		data[col.ErrorMsg] = t.ErrorMsg
	}
	if len(t.TextResult) > 0 {
		data[col.TextResult] = t.TextResult
	}
	if t.DurationSeconds != 0 {
		data[col.DurationSeconds] = t.DurationSeconds
	}

	_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
		Where(col.Id, t.Id).
		Data(data).
		Update()
	return err
}
// UpdateFailedKeepTmpGlobal marks the task as failed (state = 3) with the
// given error message and sets phase = 1. The statement does not touch
// tmp_file, so the temporary file reference is kept (used when the OSS upload
// fails, per the original comment).
func (d *taskDao) UpdateFailedKeepTmpGlobal(ctx context.Context, id int64, errorMsg string) error {
	query := fmt.Sprintf(`UPDATE %s SET state=3, error_msg=?, finished_at=?, phase=1, updated_at=? WHERE id=?`, public.TableNameTask)
	finishedAt := gtime.Now()
	updatedAt := gtime.Now()
	return execSQL(ctx, query, errorMsg, finishedAt, updatedAt, id)
}
// UpdateTmpAfterModelGlobal records the temporary result file of a task: it
// sets phase = 1, stores tmpFile in tmp_file and refreshes updated_at.
func (d *taskDao) UpdateTmpAfterModelGlobal(ctx context.Context, id int64, tmpFile string) error {
	query := fmt.Sprintf(`UPDATE %s SET phase=1, tmp_file=?, updated_at=NOW() WHERE id=?`, public.TableNameTask)
	return execSQL(ctx, query, tmpFile, id)
}
// RollbackToPendingGlobal puts a running task (state = 1) back to pending
// (state = 0) and resets enqueue_at and updated_at to the database's NOW().
// Rows that are not in state 1 are left untouched.
func (d *taskDao) RollbackToPendingGlobal(ctx context.Context, id int64) error {
	query := fmt.Sprintf(`UPDATE %s SET state=0, enqueue_at=NOW(), updated_at=NOW() WHERE id=? AND state=1`, public.TableNameTask)
	return execSQL(ctx, query, id)
}
// IncRetryCountGlobal increments retry_count of the task by one and refreshes
// updated_at.
func (d *taskDao) IncRetryCountGlobal(ctx context.Context, id int64) error {
	query := fmt.Sprintf(`UPDATE %s SET retry_count=retry_count+1, updated_at=NOW() WHERE id=?`, public.TableNameTask)
	return execSQL(ctx, query, id)
}
// RequeueForRetryGlobal moves a failed task (state = 3, not soft-deleted) back
// to pending (state = 0), increments retry_count and sets enqueue_at to
// enqueueAt. Rows in any other state are left untouched.
func (d *taskDao) RequeueForRetryGlobal(ctx context.Context, id int64, enqueueAt time.Time) error {
	query := fmt.Sprintf(`UPDATE %s SET state=0, retry_count=retry_count+1, enqueue_at=?, updated_at=NOW() WHERE id=? AND state=3 AND deleted_at IS NULL`, public.TableNameTask)
	return execSQL(ctx, query, enqueueAt, id)
}
// ======================== 列表查询 ========================
// ListExpiredDownloadedGlobal lists non-deleted tasks in state 4 whose
// expire_at is set and earlier than the current time. A non-positive limit
// falls back to 200.
func (d *taskDao) ListExpiredDownloadedGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
	query := fmt.Sprintf(`SELECT * FROM %s WHERE deleted_at IS NULL AND state=4 AND expire_at IS NOT NULL AND expire_at < ? LIMIT ?`, public.TableNameTask)
	now := gtime.Now()
	return queryTasks(ctx, query, now, clampLimit(limit, 200))
}
// ListFailedRetryableGlobal lists non-deleted failed tasks (state = 3) whose
// retry_count is still below the retry_times of their model, least recently
// updated first. Tasks are joined to the model table on tenant_id and
// model_name, and the model's retry_queue_max_seconds is selected alongside
// the task columns. A non-positive limit falls back to 200.
func (d *taskDao) ListFailedRetryableGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
	query := fmt.Sprintf(`SELECT t.*, m.retry_queue_max_seconds FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state=3 AND t.retry_count < m.retry_times ORDER BY t.updated_at ASC LIMIT ?`, public.TableNameTask, public.TableNameModel)
	return queryTasks(ctx, query, clampLimit(limit, 200))
}
// ListFailedExhaustedGlobal lists non-deleted failed tasks (state = 3) whose
// retry_count has reached or exceeded the retry_times of their model, least
// recently updated first. Tasks are joined to the model table on tenant_id
// and model_name. A non-positive limit falls back to 200.
func (d *taskDao) ListFailedExhaustedGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
	query := fmt.Sprintf(`SELECT t.* FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state=3 AND t.retry_count >= m.retry_times ORDER BY t.updated_at ASC LIMIT ?`, public.TableNameTask, public.TableNameModel)
	return queryTasks(ctx, query, clampLimit(limit, 200))
}
// ListTimeoutTasksGlobal lists non-deleted tasks that are still pending or
// running (state 0 or 1) and were created more than the model's
// expected_seconds ago. Models with expected_seconds <= 0 are excluded.
// Tasks are joined to the model table on tenant_id and model_name. A
// non-positive limit falls back to 200.
func (d *taskDao) ListTimeoutTasksGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
	query := fmt.Sprintf(`SELECT t.* FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state IN (0,1) AND m.expected_seconds > 0 AND t.created_at < (NOW() - (m.expected_seconds || ' seconds')::interval) LIMIT ?`, public.TableNameTask, public.TableNameModel)
	return queryTasks(ctx, query, clampLimit(limit, 200))
}
// HardDeleteByIDGlobal physically deletes the task row with the given id
// (a DELETE statement, not a soft delete via deleted_at).
func (d *taskDao) HardDeleteByIDGlobal(ctx context.Context, id int64) error {
	query := fmt.Sprintf(`DELETE FROM %s WHERE id=?`, public.TableNameTask)
	return execSQL(ctx, query, id)
}
// ======================== 内部辅助 ========================
// queryTasks runs a raw SELECT with bound arguments and scans every returned
// row into an AsynchTask. A query error yields (nil, err); a scan error is
// returned together with whatever the scan left in the slice.
func queryTasks(ctx context.Context, sql string, args ...any) ([]*entity.AsynchTask, error) {
	rows, queryErr := gfdb.DB(ctx).GetAll(ctx, sql, args...)
	if queryErr != nil {
		return nil, queryErr
	}

	var tasks []*entity.AsynchTask
	scanErr := rows.Structs(&tasks)
	return tasks, scanErr
}
// clampLimit returns limit when it is positive and defaultVal otherwise.
// Despite the name, no upper bound is applied: a limit larger than defaultVal
// is returned unchanged.
func clampLimit(limit, defaultVal int) int {
	if limit > 0 {
		return limit
	}
	return defaultVal
}
// UpdateColumns updates the task row with the given id from the fields of
// data (struct variant). It delegates to updateTask, which issues the same
// OmitEmpty() update this method previously spelled out inline, so
// zero-valued fields of data are skipped rather than written.
func (d *taskDao) UpdateColumns(ctx context.Context, id int64, data entity.AsynchTask) error {
	return updateTask(ctx, id, data)
}

29
main.go
View File

@@ -3,7 +3,6 @@ package main
import ( import (
"context" "context"
"model-gateway/model/dto" "model-gateway/model/dto"
"model-gateway/service/job"
"model-gateway/service/task" "model-gateway/service/task"
"os" "os"
"os/signal" "os/signal"
@@ -27,9 +26,9 @@ func main() {
// 注册路由 // 注册路由
http.RouteRegister([]interface{}{ http.RouteRegister([]interface{}{
controller.Model, controller.ModelGatewayModels,
controller.Task, controller.ModelGatewayTask,
controller.Stat, controller.ModelGatewayLogsStat,
}) })
// 本地调试:可选自动触发 worker/cleaner由配置文件控制 // 本地调试:可选自动触发 worker/cleaner由配置文件控制
@@ -47,26 +46,6 @@ func main() {
} }
func startAutoRunner(ctx context.Context) { func startAutoRunner(ctx context.Context) {
// cleaner
if g.Cfg().MustGet(ctx, "asynch.cleaner.enabled").Bool() {
interval := g.Cfg().MustGet(ctx, "asynch.cleaner.intervalSeconds").Int()
if interval <= 0 {
interval = 30
}
ticker := time.NewTicker(time.Duration(interval) * time.Second)
go func() {
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
_, _ = job.Cleaner.RunOnce(ctx)
}
}
}()
}
// queryPending // queryPending
if g.Cfg().MustGet(ctx, "asynch.queryPending.enabled").Bool() { if g.Cfg().MustGet(ctx, "asynch.queryPending.enabled").Bool() {
interval := g.Cfg().MustGet(ctx, "asynch.queryPending.intervalSeconds", 10).Int() interval := g.Cfg().MustGet(ctx, "asynch.queryPending.intervalSeconds", 10).Int()
@@ -79,7 +58,7 @@ func startAutoRunner(ctx context.Context) {
case <-ctx.Done(): case <-ctx.Done():
return return
case <-ticker.C: case <-ticker.C:
if _, err := task.Task.QueryPendingTasks(ctx, &dto.QueryPendingTasksReq{Limit: limit}); err != nil { if _, err := task.ModelGatewayTask.QueryPendingTasks(ctx, &dto.QueryPendingTasksReq{Limit: limit}); err != nil {
g.Log().Warningf(ctx, "[auto-queryPending] run once failed: %v", err) g.Log().Warningf(ctx, "[auto-queryPending] run once failed: %v", err)
} }
} }

View File

@@ -0,0 +1,20 @@
package dto
import "github.com/gogf/gf/v2/frame/g"
// ListModelStatReq is the request of the model request-statistics list
// endpoint (GET /listModelStat). It carries pagination parameters plus
// optional filters; per the dc tags, pageNum defaults to 1, pageSize to 10,
// and creator/modelName are fuzzy matches.
type ListModelStatReq struct {
g.Meta `path:"/listModelStat" method:"get" tags:"统计" summary:"模型请求统计列表" dc:"按天统计模型请求次数,支持分页与条件筛选"`
PageNum int `p:"pageNum" json:"pageNum" dc:"页码默认1"`
PageSize int `p:"pageSize" json:"pageSize" dc:"每页条数默认10"`
StartDay string `p:"startDay" json:"startDay" dc:"开始日期YYYY-MM-DD可选"`
EndDay string `p:"endDay" json:"endDay" dc:"结束日期YYYY-MM-DD可选"`
// TenantID is a pointer so that "not provided" can be told apart from 0.
TenantID *int64 `p:"tenantId" json:"tenantId" dc:"租户ID可选"`
Creator string `p:"creator" json:"creator" dc:"创建人(可选,模糊匹配)"`
ModelName string `p:"modelName" json:"modelName" dc:"模型名称(可选,模糊匹配)"`
}
// ListModelStatRes is the response of the model request-statistics list
// endpoint: one page of rows plus the total row count.
type ListModelStatRes struct {
List any `json:"list" dc:"列表数据"`
Total int64 `json:"total" dc:"总数"`
}

View File

@@ -103,7 +103,7 @@ type GetModelReq struct {
} }
type GetModelRes struct { type GetModelRes struct {
Model *entity.AsynchModel `json:"model" dc:"模型配置详情"` Model *entity.ModelGatewayModel `json:"model" dc:"模型配置详情"`
} }
// ListModelReq 配置列表 // ListModelReq 配置列表

View File

@@ -1,6 +1,8 @@
package dto package dto
import "github.com/gogf/gf/v2/frame/g" import (
"github.com/gogf/gf/v2/frame/g"
)
// CreateTaskReq 创建异步任务 // CreateTaskReq 创建异步任务
type CreateTaskReq struct { type CreateTaskReq struct {
@@ -8,7 +10,6 @@ type CreateTaskReq struct {
ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"模型名称"` ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"模型名称"`
BizName string `p:"bizName" json:"bizName" dc:"业务名称(调用方模块/系统,用于统计)"` BizName string `p:"bizName" json:"bizName" dc:"业务名称(调用方模块/系统,用于统计)"`
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址(可选,用于后续业务通知)"` CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址(可选,用于后续业务通知)"`
InputRef string `p:"inputRef" json:"inputRef" dc:"输入引用如OSS/文件引用等)"`
RequestPayload map[string]any `p:"requestPayload" json:"requestPayload" dc:"请求负载(透传给模型服务)"` RequestPayload map[string]any `p:"requestPayload" json:"requestPayload" dc:"请求负载(透传给模型服务)"`
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"` EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
BuildType int64 `json:"buildType" dc:"构建类型1-提示词构建 2-节点构建"` BuildType int64 `json:"buildType" dc:"构建类型1-提示词构建 2-节点构建"`
@@ -71,6 +72,7 @@ type GetTaskBatchItem struct {
TaskID string `json:"taskId" dc:"任务ID"` TaskID string `json:"taskId" dc:"任务ID"`
State int `json:"state" dc:"任务状态"` State int `json:"state" dc:"任务状态"`
OssFile string `json:"ossFile" dc:"结果文件OSS地址"` OssFile string `json:"ossFile" dc:"结果文件OSS地址"`
TextResult map[string]any `json:"textResult" dc:"文本结果"`
} }
type GetTaskBatchRes struct { type GetTaskBatchRes struct {
@@ -83,8 +85,9 @@ type ListTaskReq struct {
PageNum int `p:"pageNum" json:"pageNum" dc:"页码默认1"` PageNum int `p:"pageNum" json:"pageNum" dc:"页码默认1"`
PageSize int `p:"pageSize" json:"pageSize" dc:"每页条数默认10"` PageSize int `p:"pageSize" json:"pageSize" dc:"每页条数默认10"`
ModelName string `p:"modelName" json:"modelName" dc:"模型名称(模糊匹配)"` ModelName string `p:"modelName" json:"modelName" dc:"模型名称(模糊匹配)"`
BizName string `p:"bizName" json:"bizName" dc:"业务名称"`
TaskID string `p:"taskId" json:"taskId" dc:"任务ID模糊匹配"` TaskID string `p:"taskId" json:"taskId" dc:"任务ID模糊匹配"`
State *int `p:"state" json:"state" dc:"任务状态0/1/2/3/4可选"` State int `p:"state" json:"state" dc:"任务状态0/1/2/3/4可选"`
} }
type ListTaskRes struct { type ListTaskRes struct {
@@ -102,12 +105,3 @@ type RunWorkReq struct {
type RunWorkRes struct { type RunWorkRes struct {
Claimed int `json:"claimed" dc:"本次抢占并处理的任务数"` Claimed int `json:"claimed" dc:"本次抢占并处理的任务数"`
} }
// CleanWorkReq 手动触发 cleaner 执行一次(由上层定时任务调用)
type CleanWorkReq struct {
g.Meta `path:"/cleanWork" method:"post" tags:"任务管理" summary:"执行一次Cleaner" dc:"手动触发一次清理/重试(用于由上层定时任务控制)"`
}
type CleanWorkRes struct {
Ok bool `json:"ok" dc:"是否执行成功"`
}

View File

@@ -1,20 +0,0 @@
package dto
import "github.com/gogf/gf/v2/frame/g"
// ListModelStatReq is the request of the model request-statistics list
// endpoint (GET /listModelStat). It carries pagination parameters plus
// optional filters; per the dc tags, pageNum defaults to 1, pageSize to 10,
// and creator/modelName are fuzzy matches.
type ListModelStatReq struct {
g.Meta `path:"/listModelStat" method:"get" tags:"统计" summary:"模型请求统计列表" dc:"按天统计模型请求次数,支持分页与条件筛选"`
PageNum int `p:"pageNum" json:"pageNum" dc:"页码默认1"`
PageSize int `p:"pageSize" json:"pageSize" dc:"每页条数默认10"`
StartDay string `p:"startDay" json:"startDay" dc:"开始日期YYYY-MM-DD可选"`
EndDay string `p:"endDay" json:"endDay" dc:"结束日期YYYY-MM-DD可选"`
// TenantID is a pointer so that "not provided" can be told apart from 0.
TenantID *int64 `p:"tenantId" json:"tenantId" dc:"租户ID可选"`
Creator string `p:"creator" json:"creator" dc:"创建人(可选,模糊匹配)"`
ModelName string `p:"modelName" json:"modelName" dc:"模型名称(可选,模糊匹配)"`
}
// ListModelStatRes is the response of the model request-statistics list
// endpoint: one page of rows plus the total row count.
type ListModelStatRes struct {
List any `json:"list" dc:"列表数据"`
Total int64 `json:"total" dc:"总数"`
}

View File

@@ -1,89 +0,0 @@
package entity
import (
"gitea.redpowerfuture.com/red-future/common/beans"
"github.com/gogf/gf/v2/os/gtime"
)
// asynchTaskCol holds the database column name of every AsynchTask field.
// Each field is a string carrying the column name; the shared base columns
// come from the embedded beans.SQLBaseCol.
type asynchTaskCol struct {
beans.SQLBaseCol
ModelName string
TaskID string
BizName string
CallbackURL string
ModelKey string
State string
OssFile string
FileType string
FileSize string
ErrorMsg string
StartedAt string
FinishedAt string
DurationSeconds string
ExpireAt string
RetryCount string
EnqueueAt string
Phase string
TmpFile string
InputRef string
RequestPayload string
TextResult string
EpicycleId string
ExpendTokens string
}
// AsynchTaskCol is the column-name table for AsynchTask. Use it instead of
// string literals when building queries (e.g. AsynchTaskCol.State is "state").
var AsynchTaskCol = asynchTaskCol{
SQLBaseCol: beans.DefSQLBaseCol,
ModelName: "model_name",
TaskID: "task_id",
BizName: "biz_name",
CallbackURL: "callback_url",
ModelKey: "model_key",
State: "state",
OssFile: "oss_file",
FileType: "file_type",
FileSize: "file_size",
ErrorMsg: "error_msg",
StartedAt: "started_at",
FinishedAt: "finished_at",
DurationSeconds: "duration_seconds",
ExpireAt: "expire_at",
RetryCount: "retry_count",
EnqueueAt: "enqueue_at",
Phase: "phase",
TmpFile: "tmp_file",
InputRef: "input_ref",
RequestPayload: "request_payload",
TextResult: "text_result",
EpicycleId: "epicycle_id",
ExpendTokens: "expend_tokens",
}
// AsynchTask is the ORM entity of an asynchronous task row. The shared base
// columns come from the inlined beans.SQLBaseDO.
type AsynchTask struct {
beans.SQLBaseDO `orm:",inline"`
ModelName string `orm:"model_name" json:"modelName"`
TaskID string `orm:"task_id" json:"taskId"`
BizName string `orm:"biz_name" json:"bizName"`
CallbackURL string `orm:"callback_url" json:"callbackUrl"`
ModelKey string `orm:"model_key" json:"modelKey"`
State int `orm:"state" json:"state"` // 0 queued / 1 running / 2 succeeded / 3 failed / 4 downloaded
OssFile string `orm:"oss_file" json:"ossFile"`
FileType string `orm:"file_type" json:"fileType"`
FileSize int64 `orm:"file_size" json:"fileSize"`
ErrorMsg string `orm:"error_msg" json:"errorMsg"`
StartedAt *gtime.Time `orm:"started_at" json:"startedAt"`
FinishedAt *gtime.Time `orm:"finished_at" json:"finishedAt"`
DurationSeconds int64 `orm:"duration_seconds" json:"durationSeconds"`
ExpireAt *gtime.Time `orm:"expire_at" json:"expireAt"` // expiry time once the task is downloaded (state=4)
RetryCount int `orm:"retry_count" json:"retryCount"`
EnqueueAt *gtime.Time `orm:"enqueue_at" json:"enqueueAt"`
Phase int `orm:"phase" json:"phase"` // 0 model phase / 1 OSS phase
TmpFile string `orm:"tmp_file" json:"tmpFile"` // path of the temporary result file
InputRef string `orm:"input_ref" json:"inputRef"`
RequestPayload map[string]any `orm:"request_payload" json:"requestPayload"`
// NOTE(review): the JSON key is "text", not "textResult" — confirm this is intended.
TextResult map[string]any `orm:"text_result" json:"text"`
EpicycleId int64 `orm:"epicycle_id" json:"epicycleId"` // round ID identifying tasks of the same round
ExpendTokens int64 `orm:"expend_tokens" json:"expendTokens"` // number of tokens consumed
// RetryQueueMaxSeconds is excluded from JSON; presumably filled only by
// queries that also select the model's retry_queue_max_seconds — confirm.
RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"-"`
}

View File

@@ -1,57 +0,0 @@
package entity
import (
"gitea.redpowerfuture.com/red-future/common/beans"
)
// LogsModelPpCol holds the database column name of every LogsModelOp field.
// Each field is a string carrying the column name; the shared base columns
// come from the embedded beans.SQLBaseCol.
//
// NOTE(review): "Pp" looks like a typo of "Op" (the value is named
// LogsModelOpCol). The type is exported, so renaming it would change the
// package API.
type LogsModelPpCol struct {
beans.SQLBaseCol
IP string
UserAgent string
APIPath string
HttpMethod string
BizName string
ModelName string
TaskID string
OpType string
Success string
ErrorMsg string
CostMs string
RequestPayload string
ResponsePayload string
}
// LogsModelOpCol is the column-name table for LogsModelOp.
var LogsModelOpCol = LogsModelPpCol{
SQLBaseCol: beans.DefSQLBaseCol,
IP: "ip",
UserAgent: "user_agent",
APIPath: "api_path",
HttpMethod: "http_method",
BizName: "biz_name",
ModelName: "model_name",
TaskID: "task_id",
OpType: "op_type",
Success: "success",
ErrorMsg: "error_msg",
CostMs: "cost_ms",
RequestPayload: "request_payload",
ResponsePayload: "response_payload",
}
// LogsModelOp is the ORM entity of an operation-log row (task creation and
// similar operations). The shared base columns come from the inlined
// beans.SQLBaseDO.
type LogsModelOp struct {
beans.SQLBaseDO `orm:",inline"`
IP string `orm:"ip" json:"ip"`
UserAgent string `orm:"user_agent" json:"userAgent"`
APIPath string `orm:"api_path" json:"apiPath"`
HttpMethod string `orm:"http_method" json:"httpMethod"`
BizName string `orm:"biz_name" json:"bizName"`
ModelName string `orm:"model_name" json:"modelName"`
TaskID string `orm:"task_id" json:"taskId"`
OpType string `orm:"op_type" json:"opType"`
Success int `orm:"success" json:"success"`
ErrorMsg string `orm:"error_msg" json:"errorMsg"`
CostMs int64 `orm:"cost_ms" json:"costMs"`
// The payloads are untyped (any).
RequestPayload any `orm:"request_payload" json:"requestPayload"`
ResponsePayload any `orm:"response_payload" json:"responsePayload"`
}

View File

@@ -1,38 +0,0 @@
package entity
import (
"github.com/gogf/gf/v2/os/gtime"
)
// LogsModelStatCol holds the database column name of every LogsModelStat
// field; each field is a string carrying the column name.
type LogsModelStatCol struct {
Day string
TenantId string
Creator string
ModelName string
RequestCount string
CreatedAt string
UpdatedAt string
}
// LogsModelStatCols is the column-name table for LogsModelStat.
var LogsModelStatCols = LogsModelStatCol{
Day: "day",
TenantId: "tenant_id",
Creator: "creator",
ModelName: "model_name",
RequestCount: "request_count",
CreatedAt: "created_at",
UpdatedAt: "updated_at",
}
// LogsModelStat is a per-day statistic: the number of requests for one
// day / tenant / creator / model combination.
// Note (from the original comment): this entity does not use the common
// SQLBaseDO; it relies on the composite unique key
// (day, tenant_id, creator, model_name) for an atomic UPSERT increment.
type LogsModelStat struct {
Day *gtime.Time `orm:"day" json:"day"` // date (only the date part is meant to be used)
TenantId int64 `orm:"tenant_id" json:"tenantId"` // tenant ID
Creator string `orm:"creator" json:"creator"` // creator / operator
ModelName string `orm:"model_name" json:"modelName"` // model name
RequestCount int64 `orm:"request_count" json:"requestCount"` // number of requests
CreatedAt *gtime.Time `orm:"created_at" json:"createdAt"` // creation time
UpdatedAt *gtime.Time `orm:"updated_at" json:"updatedAt"` // last update time
}

View File

@@ -0,0 +1,56 @@
package entity
import "gitea.redpowerfuture.com/red-future/common/beans"
// modelGatewayLogsOpCol holds the database column name of every
// ModelGatewayLogsOp field. Each field is a string carrying the column name;
// the shared base columns come from the embedded beans.SQLBaseCol.
type modelGatewayLogsOpCol struct {
beans.SQLBaseCol
IP string
UserAgent string
APIPath string
HttpMethod string
BizName string
ModelName string
TaskID string
OpType string
Success string
ErrorMsg string
CostMs string
RequestPayload string
ResponsePayload string
}
// ModelGatewayLogsOpCol is the column-name table for ModelGatewayLogsOp.
var ModelGatewayLogsOpCol = modelGatewayLogsOpCol{
SQLBaseCol: beans.DefSQLBaseCol,
IP: "ip",
UserAgent: "user_agent",
APIPath: "api_path",
HttpMethod: "http_method",
BizName: "biz_name",
ModelName: "model_name",
TaskID: "task_id",
OpType: "op_type",
Success: "success",
ErrorMsg: "error_msg",
CostMs: "cost_ms",
RequestPayload: "request_payload",
ResponsePayload: "response_payload",
}
// ModelGatewayLogsOp is the ORM entity of an operation-log row. The shared
// base columns come from the inlined beans.SQLBaseDO.
type ModelGatewayLogsOp struct {
beans.SQLBaseDO `orm:",inline"`
IP string `orm:"ip" json:"ip"`
UserAgent string `orm:"user_agent" json:"userAgent"`
APIPath string `orm:"api_path" json:"apiPath"`
HttpMethod string `orm:"http_method" json:"httpMethod"`
BizName string `orm:"biz_name" json:"bizName"`
ModelName string `orm:"model_name" json:"modelName"`
TaskID string `orm:"task_id" json:"taskId"`
OpType string `orm:"op_type" json:"opType"`
Success int `orm:"success" json:"success"`
ErrorMsg string `orm:"error_msg" json:"errorMsg"`
CostMs int64 `orm:"cost_ms" json:"costMs"`
// RequestPayload is a pointer and may be nil; check before dereferencing.
RequestPayload *RequestPayload `orm:"request_payload" json:"requestPayload"`
ResponsePayload map[string]any `orm:"response_payload" json:"responsePayload"`
}

View File

@@ -0,0 +1,35 @@
package entity
import "github.com/gogf/gf/v2/os/gtime"
// ModelGatewayLogsStatCol holds the database column name of every
// ModelGatewayLogsStat field; each field is a string carrying the column name.
type ModelGatewayLogsStatCol struct {
Day string
TenantId string
Creator string
ModelName string
RequestCount string
CreatedAt string
UpdatedAt string
}
// ModelGatewayLogsStatCols is the column-name table for ModelGatewayLogsStat.
var ModelGatewayLogsStatCols = ModelGatewayLogsStatCol{
Day: "day",
TenantId: "tenant_id",
Creator: "creator",
ModelName: "model_name",
RequestCount: "request_count",
CreatedAt: "created_at",
UpdatedAt: "updated_at",
}
// ModelGatewayLogsStat is a per-day statistic row: the request count for one
// day / tenant / creator / model combination.
type ModelGatewayLogsStat struct {
Day *gtime.Time `orm:"day" json:"day"`
// TenantId is unsigned here.
TenantId uint64 `orm:"tenant_id" json:"tenantId"`
Creator string `orm:"creator" json:"creator"`
ModelName string `orm:"model_name" json:"modelName"`
RequestCount int64 `orm:"request_count" json:"requestCount"`
CreatedAt *gtime.Time `orm:"created_at" json:"createdAt"`
UpdatedAt *gtime.Time `orm:"updated_at" json:"updatedAt"`
}

View File

@@ -2,7 +2,7 @@ package entity
import "gitea.redpowerfuture.com/red-future/common/beans" import "gitea.redpowerfuture.com/red-future/common/beans"
type asynchModelCol struct { type modelGatewayModelCol struct {
beans.SQLBaseCol beans.SQLBaseCol
ModelName string ModelName string
ModelType string ModelType string
@@ -33,9 +33,10 @@ type asynchModelCol struct {
FirstFrame string FirstFrame string
LastFrame string LastFrame string
CallbackUrl string CallbackUrl string
MaxTokens string
} }
var AsynchModelCol = asynchModelCol{ var ModelGatewayModelCol = modelGatewayModelCol{
SQLBaseCol: beans.DefSQLBaseCol, SQLBaseCol: beans.DefSQLBaseCol,
ModelName: "model_name", ModelName: "model_name",
ModelType: "model_type", ModelType: "model_type",
@@ -66,10 +67,10 @@ var AsynchModelCol = asynchModelCol{
FirstFrame: "first_frame", FirstFrame: "first_frame",
LastFrame: "last_frame", LastFrame: "last_frame",
CallbackUrl: "callback_url", CallbackUrl: "callback_url",
MaxTokens: "max_tokens",
} }
// AsynchModel 异步模型配置 type ModelGatewayModel struct {
type AsynchModel struct {
beans.SQLBaseDO `orm:",inline"` beans.SQLBaseDO `orm:",inline"`
ModelName string `orm:"model_name" json:"modelName"` ModelName string `orm:"model_name" json:"modelName"`
ModelType int `orm:"model_type" json:"modelType"` ModelType int `orm:"model_type" json:"modelType"`
@@ -80,7 +81,7 @@ type AsynchModel struct {
RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"` RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"`
ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"` ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"`
ResponseBody string `orm:"response_body" json:"responseBody"` ResponseBody string `orm:"response_body" json:"responseBody"`
ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"` ResponseTokenField string `orm:"response_token_field" json:"tokenField"`
RequiredFields []string `orm:"required_fields" json:"requiredFields"` RequiredFields []string `orm:"required_fields" json:"requiredFields"`
IsPrivate *int `orm:"is_private" json:"isPrivate"` IsPrivate *int `orm:"is_private" json:"isPrivate"`
IsChatModel *int `orm:"is_chat_model" json:"isChatModel"` IsChatModel *int `orm:"is_chat_model" json:"isChatModel"`
@@ -91,7 +92,7 @@ type AsynchModel struct {
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"` TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
RetryTimes int `orm:"retry_times" json:"retryTimes"` RetryTimes int `orm:"retry_times" json:"retryTimes"`
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"` AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
IsOwner *int `json:"isOwner" orm:"is_owner"` IsOwner *int `orm:"is_owner" json:"isOwner"`
OperatorName string `orm:"operator_name" json:"operatorName"` OperatorName string `orm:"operator_name" json:"operatorName"`
TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"` TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"`
ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"` ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"`
@@ -100,4 +101,5 @@ type AsynchModel struct {
FirstFrame string `orm:"first_frame" json:"firstFrame"` FirstFrame string `orm:"first_frame" json:"firstFrame"`
LastFrame string `orm:"last_frame" json:"lastFrame"` LastFrame string `orm:"last_frame" json:"lastFrame"`
CallbackUrl string `orm:"callback_url" json:"callbackUrl"` CallbackUrl string `orm:"callback_url" json:"callbackUrl"`
MaxTokens int `orm:"max_tokens" json:"maxTokens"`
} }

View File

@@ -0,0 +1,76 @@
package entity
import (
"gitea.redpowerfuture.com/red-future/common/beans"
)
// modelGatewayTaskCol holds the database column name of every
// ModelGatewayTask field. Each field is a string carrying the column name;
// the shared base columns come from the embedded beans.SQLBaseCol.
type modelGatewayTaskCol struct {
beans.SQLBaseCol
ModelName string
TaskID string
BizName string
CallbackURL string
State string
Phase string
ErrorMsg string
ResultFile string
TextResult string
ExpendTokens string
DurationSeconds string
RetryCount string
TmpFile string
RequestPayload string
EpicycleId string
}
// ModelGatewayTaskCol is the column-name table for ModelGatewayTask. Use it
// instead of string literals when building queries.
var ModelGatewayTaskCol = modelGatewayTaskCol{
SQLBaseCol: beans.DefSQLBaseCol,
ModelName: "model_name",
TaskID: "task_id",
BizName: "biz_name",
CallbackURL: "callback_url",
State: "state",
Phase: "phase",
ErrorMsg: "error_msg",
ResultFile: "result_file",
TextResult: "text_result",
ExpendTokens: "expend_tokens",
DurationSeconds: "duration_seconds",
RetryCount: "retry_count",
TmpFile: "tmp_file",
RequestPayload: "request_payload",
EpicycleId: "epicycle_id",
}
// ModelGatewayTask is the ORM entity of a model-gateway task row. The shared
// base columns come from the inlined beans.SQLBaseDO.
type ModelGatewayTask struct {
beans.SQLBaseDO `orm:",inline"`
ModelName string `orm:"model_name" json:"modelName"`
TaskID string `orm:"task_id" json:"taskId"`
BizName string `orm:"biz_name" json:"bizName"`
CallbackURL string `orm:"callback_url" json:"callbackUrl"`
State int `orm:"state" json:"state"`
Phase int `orm:"phase" json:"phase"`
ErrorMsg string `orm:"error_msg" json:"errorMsg"`
// ResultFile is a pointer and may be nil (e.g. when no file result exists);
// check before dereferencing.
ResultFile *ResultFile `orm:"result_file" json:"resultFile"`
// NOTE(review): the JSON key is "text", not "textResult" — confirm this is intended.
TextResult map[string]any `orm:"text_result" json:"text"`
ExpendTokens int64 `orm:"expend_tokens" json:"expendTokens"`
DurationSeconds int64 `orm:"duration_seconds" json:"durationSeconds"`
RetryCount int `orm:"retry_count" json:"retryCount"`
TmpFile string `orm:"tmp_file" json:"tmpFile"`
// RequestPayload is a pointer and may be nil; check before dereferencing.
RequestPayload *RequestPayload `orm:"request_payload" json:"requestPayload"`
EpicycleId int64 `orm:"epicycle_id" json:"epicycleId"`
}
// ResultFile describes a task's OSS result file: its OSS location, file type
// and size. It is the value type of ModelGatewayTask.ResultFile
// (column result_file); presumably stored as JSON — confirm.
type ResultFile struct {
OssFile string `json:"ossFile"`
FileType string `json:"fileType"`
FileSize int64 `json:"fileSize"`
}
// RequestPayload is the structured request of a task: the headers and the
// body of the request.
type RequestPayload struct {
Headers map[string]string `json:"headers"`
Body map[string]any `json:"body"`
}

View File

@@ -77,14 +77,14 @@ type CallbackPayload struct {
} }
// TriggerCallback 任务的回调 // TriggerCallback 任务的回调
func TriggerCallback(ctx context.Context, t *entity.AsynchTask) { func TriggerCallback(ctx context.Context, t *entity.ModelGatewayTask) {
headers := util.ForwardHeaders(ctx) headers := util.ForwardHeaders(ctx)
var resp struct{} var resp struct{}
payload := CallbackPayload{ payload := CallbackPayload{
TaskId: t.TaskID, TaskId: t.TaskID,
State: t.State, State: t.State,
OssFile: t.OssFile, OssFile: t.ResultFile.OssFile,
FileType: t.FileType, FileType: t.ResultFile.FileType,
Messages: t.TextResult, Messages: t.TextResult,
ErrorMsg: t.ErrorMsg, ErrorMsg: t.ErrorMsg,
} }
@@ -111,7 +111,7 @@ type PromptsCallbackPayload struct {
} }
// TriggerPromptsCallback 任务成功后的提示词回调 // TriggerPromptsCallback 任务成功后的提示词回调
func TriggerPromptsCallback(ctx context.Context, t *entity.AsynchTask, epicycleId int64) { func TriggerPromptsCallback(ctx context.Context, t *entity.ModelGatewayTask, epicycleId int64) {
callbackURL := "prompts-core/session/callback" callbackURL := "prompts-core/session/callback"
headers := util.ForwardHeaders(ctx) headers := util.ForwardHeaders(ctx)
var resp struct{} var resp struct{}

View File

@@ -1,99 +0,0 @@
package job
import (
"context"
"model-gateway/model/dto"
"model-gateway/service/queue"
"os"
"time"
"model-gateway/dao"
"github.com/gogf/gf/v2/frame/g"
)
// Cleaner is the package-level cleaner instance used to trigger cleanup runs.
var Cleaner = &cleaner{}
// cleaner is a stateless type that groups the cleanup/retry job methods.
type cleaner struct{}
// RunOnce performs one cleanup/retry pass; it is meant to be triggered by an
// external scheduler. The pass has four independent steps:
//
//  1. hard-delete downloaded (state=4) tasks that have expired, removing
//     their temporary file;
//  2. mark timed-out tasks as failed and release their queue slot;
//  3. re-enqueue failed (state=3) tasks that still have retries left;
//  4. hard-delete failed tasks whose retries are exhausted.
//
// Every step is best effort: a failing query is logged and the pass moves on
// to the next step. The method therefore always returns Ok=true and a nil
// error.
func (c *cleaner) RunOnce(ctx context.Context) (res *dto.CleanWorkRes, err error) {
	// 1) Expired downloaded tasks: remove the temp file and the row.
	expired, listErr := dao.Task.ListExpiredDownloadedGlobal(ctx, 200)
	if listErr != nil {
		g.Log().Errorf(ctx, "[清理] 查询已下载过期任务失败: %v", listErr)
	} else {
		for _, t := range expired {
			// Best effort: the file may already be gone, and a row that
			// fails to delete is picked up again by the next pass.
			_ = os.Remove(t.TmpFile)
			_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
		}
		g.Log().Infof(ctx, "[清理] 已下载过期任务清理完成, count=%d", len(expired))
	}

	// 2) Timed-out tasks: mark as failed and free the queue slot.
	timedOut, listErr := dao.Task.ListTimeoutTasksGlobal(ctx, 200)
	if listErr != nil {
		g.Log().Errorf(ctx, "[清理] 查询超时任务失败: %v", listErr)
	} else {
		for _, t := range timedOut {
			t.ErrorMsg = "任务超时自动失败"
			// Best effort: a task that could not be updated is still listed
			// as timed out on the next pass.
			_ = dao.Task.UpdateFailedGlobal(ctx, t)
			queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
		}
		g.Log().Infof(ctx, "[清理] 超时任务处理完成, count=%d", len(timedOut))
	}

	// 3) Failed tasks with retries left: re-enqueue (state 3 -> 0).
	retryable, listErr := dao.Task.ListFailedRetryableGlobal(ctx, 200)
	if listErr != nil {
		g.Log().Errorf(ctx, "[清理] 查询可重试任务失败: %v", listErr)
	} else {
		for _, t := range retryable {
			// The model configuration supplies the queue limit and slot timeout.
			m, getErr := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
			if getErr != nil || m == nil {
				continue
			}

			// Strictly reserve a queue slot first; when none is available
			// the task stays failed and is retried on a later pass.
			acquired := false
			limit := queue.GetRuntimeQueueLimit(ctx, t.ModelName, m.MaxConcurrency*2)
			if limit > 0 {
				ok, _ := queue.AcquireQueueSlot(ctx, t.ModelName, t.TaskID, limit, m.TimeoutSeconds)
				if !ok {
					continue
				}
				acquired = true
			}

			// retry_queue_max_seconds controls where the retry is queued:
			//   = 0: jump to the head of the queue;
			//   > 0: jump to the head once the task is at least that old,
			//        otherwise go to the tail.
			// "Head of the queue" is an enqueue time far in the past.
			now := time.Now()
			enqueueAt := now
			maxSeconds := t.RetryQueueMaxSeconds
			switch {
			case maxSeconds == 0:
				enqueueAt = now.Add(-100 * 365 * 24 * time.Hour)
			case maxSeconds > 0 && t.CreatedAt != nil:
				if now.Sub(t.CreatedAt.Time) >= time.Duration(maxSeconds)*time.Second {
					enqueueAt = now.Add(-100 * 365 * 24 * time.Hour)
				}
			}

			if reqErr := dao.Task.RequeueForRetryGlobal(ctx, t.Id, enqueueAt); reqErr != nil {
				g.Log().Warningf(ctx, "[清理] 任务重新入队失败, taskId=%s: %v", t.TaskID, reqErr)
				// The task is still failed, so give back the slot reserved
				// above instead of leaving it held.
				if acquired {
					queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
				}
			}
		}
		g.Log().Infof(ctx, "[清理] 可重试任务重新入队完成, count=%d", len(retryable))
	}

	// 4) Failed tasks with retries exhausted: hard delete.
	exhausted, listErr := dao.Task.ListFailedExhaustedGlobal(ctx, 200)
	if listErr != nil {
		g.Log().Errorf(ctx, "[清理] 查询重试耗尽任务失败: %v", listErr)
	} else {
		for _, t := range exhausted {
			// Best effort, as in step 1.
			_ = os.Remove(t.TmpFile)
			// Safety net: release the slot (per the original comment this is
			// idempotent if it was already released).
			queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
			_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
		}
		g.Log().Infof(ctx, "[清理] 重试耗尽任务清理完成, count=%d", len(exhausted))
	}

	return &dto.CleanWorkRes{Ok: true}, nil
}

View File

@@ -18,7 +18,7 @@ import (
"github.com/gogf/gf/v2/util/gconv" "github.com/gogf/gf/v2/util/gconv"
) )
var Model = &modelService{} var ModelGatewayModels = &modelService{}
type modelService struct{} type modelService struct{}
@@ -37,7 +37,7 @@ func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (*dt
} }
// 3入库 // 3入库
id, err := dao.Model.Insert(ctx, util.ConvertTo[entity.AsynchModel](req)) id, err := dao.ModelGatewayModels.Insert(ctx, util.ConvertTo[entity.ModelGatewayModel](req))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -56,27 +56,27 @@ func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) erro
req.IsOwner = gconv.PtrInt(1) req.IsOwner = gconv.PtrInt(1)
if isAdmin, _ := gateway.IsSuperAdmin(ctx); isAdmin { if isAdmin, _ := gateway.IsSuperAdmin(ctx); isAdmin {
req.IsOwner = gconv.PtrInt(0) req.IsOwner = gconv.PtrInt(0)
_, err := dao.Model.Update(ctx, util.ConvertTo[entity.AsynchModel](req)) _, err := dao.ModelGatewayModels.Update(ctx, util.ConvertTo[entity.ModelGatewayModel](req))
return err return err
} }
// 3跨租户判断超管的模型不允许直接修改走插入新记录 // 3跨租户判断超管的模型不允许直接修改走插入新记录
model, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{ model, err := dao.ModelGatewayModels.GetByAcrossTenant(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.ID}, SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
}) })
if err != nil { if err != nil {
return err return err
} }
if model.TenantId == 1 { if model.TenantId == 1 {
_, err = dao.Model.Insert(ctx, util.ConvertTo[entity.AsynchModel](req)) _, err = dao.ModelGatewayModels.Insert(ctx, util.ConvertTo[entity.ModelGatewayModel](req))
return err return err
} }
_, err = dao.Model.Update(ctx, util.ConvertTo[entity.AsynchModel](req)) _, err = dao.ModelGatewayModels.Update(ctx, util.ConvertTo[entity.ModelGatewayModel](req))
return err return err
} }
// Delete 删除模型 // Delete 删除模型
func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) error { func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) error {
_, err := dao.Model.Delete(ctx, &entity.AsynchModel{ _, err := dao.ModelGatewayModels.Delete(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.ID}, SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
}) })
return err return err
@@ -91,7 +91,7 @@ func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetM
if g.IsEmpty(req.ID) { if g.IsEmpty(req.ID) {
req.Creator = user.UserName req.Creator = user.UserName
} }
model, err := dao.Model.Get(ctx, &entity.AsynchModel{ model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{ SQLBaseDO: beans.SQLBaseDO{
Id: req.ID, Id: req.ID,
Creator: user.UserName, Creator: user.UserName,
@@ -123,7 +123,7 @@ func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (*dto.Li
req.Creator = user.UserName req.Creator = user.UserName
// 3查询 // 3查询
models, total, err := dao.Model.GetByCreatorAndPlatform(ctx, req) models, total, err := dao.ModelGatewayModels.GetByCreatorAndPlatform(ctx, req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -134,7 +134,7 @@ func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (*dto.Li
// UpdateChatModel 设置会话模型 // UpdateChatModel 设置会话模型
func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error { func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error {
// 1校验新模型存在 // 1校验新模型存在
newModel, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{ newModel, err := dao.ModelGatewayModels.GetByAcrossTenant(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.Id}, SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
}) })
if err != nil || newModel == nil { if err != nil || newModel == nil {
@@ -146,7 +146,7 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM
if err != nil { if err != nil {
return err return err
} }
currentModel, err := dao.Model.Get(ctx, &entity.AsynchModel{ currentModel, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName}, SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
IsChatModel: gconv.PtrInt(1), IsChatModel: gconv.PtrInt(1),
}) })
@@ -161,7 +161,7 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM
return errors.New("当前模型为非推理模型,不能设置为会话模型") return errors.New("当前模型为非推理模型,不能设置为会话模型")
} }
if currentModel.Id != req.Id { if currentModel.Id != req.Id {
_, err = dao.Model.Update(ctx, &entity.AsynchModel{ _, err = dao.ModelGatewayModels.Update(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id}, SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id},
IsChatModel: gconv.PtrInt(0), IsChatModel: gconv.PtrInt(0),
}) })
@@ -171,7 +171,7 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM
} }
} }
_, err = dao.Model.Update(ctx, &entity.AsynchModel{ _, err = dao.ModelGatewayModels.Update(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.Id}, SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
IsChatModel: gconv.PtrInt(1), IsChatModel: gconv.PtrInt(1),
}) })
@@ -185,7 +185,7 @@ func (s *modelService) GetIsChatModel(ctx context.Context) (*dto.GetIsChatModelR
if err != nil { if err != nil {
return nil, err return nil, err
} }
model, err := dao.Model.Get(ctx, &entity.AsynchModel{ model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName}, SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
IsChatModel: gconv.PtrInt(1), IsChatModel: gconv.PtrInt(1),
}) })
@@ -203,14 +203,14 @@ func (s *modelService) clearUserChatModel(ctx context.Context) error {
if err != nil { if err != nil {
return err return err
} }
model, err := dao.Model.Get(ctx, &entity.AsynchModel{ model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName}, SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
IsChatModel: gconv.PtrInt(1), IsChatModel: gconv.PtrInt(1),
}) })
if err != nil || model == nil { if err != nil || model == nil {
return nil return nil
} }
_, err = dao.Model.Update(ctx, &entity.AsynchModel{ _, err = dao.ModelGatewayModels.Update(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Id: model.Id}, SQLBaseDO: beans.SQLBaseDO{Id: model.Id},
IsChatModel: gconv.PtrInt(0), IsChatModel: gconv.PtrInt(0),
}) })
@@ -223,7 +223,7 @@ func (s *modelService) checkChatModelUnique(ctx context.Context) error {
if err != nil { if err != nil {
return err return err
} }
model, err := dao.Model.Get(ctx, &entity.AsynchModel{ model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName}, SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
IsChatModel: gconv.PtrInt(1), IsChatModel: gconv.PtrInt(1),
}) })

View File

@@ -43,14 +43,14 @@ func AutoTune(ctx context.Context, req *dto.AutoTuneReq) (res *dto.AutoTuneRes,
req.WindowSeconds = 3600 // 默认1小时 req.WindowSeconds = 3600 // 默认1小时
} }
// 1) 读取模型配置cap按 model_name 聚合去重(如果表里有多租户重复数据,取较大上限) // 1) 读取模型配置cap按 model_name 聚合去重(如果表里有多租户重复数据,取较大上限)
var modelRows []*entity.AsynchModel var modelRows []*entity.ModelGatewayModel
if err := gfdb.DB(ctx).Model(ctx, public.TableNameModel). if err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
Where("deleted_at IS NULL"). Where("deleted_at IS NULL").
Where(entity.AsynchModelCol.Enabled, 1). Where(entity.ModelGatewayModelCol.Enabled, 1).
Scan(&modelRows); err != nil { Scan(&modelRows); err != nil {
return nil, err return nil, err
} }
modelMap := make(map[string]*entity.AsynchModel) modelMap := make(map[string]*entity.ModelGatewayModel)
for _, m := range modelRows { for _, m := range modelRows {
if m == nil || m.ModelName == "" { if m == nil || m.ModelName == "" {
continue continue

View File

@@ -2,36 +2,31 @@ package stat
import ( import (
"context" "context"
"model-gateway/model/entity"
"model-gateway/dao" "model-gateway/dao"
"model-gateway/model/dto" "model-gateway/model/dto"
) )
type statService struct{} var ModelGatewayLogsStat = &logsStatService{}
var Stat = &statService{} type logsStatService struct{}
func (s *statService) List(ctx context.Context, req *dto.ListModelStatReq) (res *dto.ListModelStatRes, err error) { func (s *logsStatService) List(ctx context.Context, req *dto.ListModelStatReq) (*dto.ListModelStatRes, error) {
pageNum, pageSize := 1, 10 if req == nil {
if req != nil { req = &dto.ListModelStatReq{}
if req.PageNum > 0 {
pageNum = req.PageNum
} }
if req.PageSize > 0 { if req.PageNum <= 0 {
pageSize = req.PageSize req.PageNum = 1
} }
if req.PageSize <= 0 {
req.PageSize = 10
} }
startDay, endDay := "", ""
var tenantID *int64 list, total, err := dao.ModelGatewayLogsStat.List(ctx, req.PageNum, req.PageSize, &entity.ModelGatewayLogsStat{
creator, modelName := "", "" Creator: req.Creator,
if req != nil { ModelName: req.ModelName,
startDay = req.StartDay })
endDay = req.EndDay
tenantID = req.TenantID
creator = req.Creator
modelName = req.ModelName
}
list, total, err := dao.Stat.List(ctx, pageNum, pageSize, startDay, endDay, tenantID, creator, modelName)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -17,25 +17,27 @@ import (
"gitea.redpowerfuture.com/red-future/common/utils" "gitea.redpowerfuture.com/red-future/common/utils"
"github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/util/gconv" "github.com/gogf/gf/v2/util/gconv"
"github.com/google/uuid" "github.com/google/uuid"
) )
var Task = &taskService{} var ModelGatewayTask = &taskService{}
type taskService struct{} type taskService struct{}
// Create 创建任务 // Create 创建任务
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) { func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
startAt := time.Now() var (
taskID := uuid.NewString() startAt = time.Now()
taskID = uuid.NewString()
)
// 1) 检查模型配置,并且获取模型 // 1) 检查模型配置,并且获取模型
userInfo, err := utils.GetUserInfo(ctx) userInfo, err := utils.GetUserInfo(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
model, err := dao.Model.Get(ctx, &entity.AsynchModel{ model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{ SQLBaseDO: beans.SQLBaseDO{
TenantId: userInfo.TenantId, TenantId: userInfo.TenantId,
Creator: userInfo.UserName, Creator: userInfo.UserName,
@@ -66,19 +68,17 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
// 异步调用:注入回调地址后提交,拿到 task_id 轮询 // 异步调用:注入回调地址后提交,拿到 task_id 轮询
req.RequestPayload = util.InjectCallbackURL(ctx, req.RequestPayload, model.CallbackUrl) req.RequestPayload = util.InjectCallbackURL(ctx, req.RequestPayload, model.CallbackUrl)
} }
storedPayload := map[string]any{ requestPayload := entity.RequestPayload{
"headers": util.ParseHeadMsgHeaders(model.HeadMsg), Body: req.RequestPayload,
"body": req.RequestPayload, Headers: util.ParseHeadMsgHeaders(model.HeadMsg),
} }
_, err = dao.Task.Insert(ctx, &entity.AsynchTask{ id, err := dao.ModelGatewayTask.Insert(ctx, &entity.ModelGatewayTask{
ModelName: req.ModelName, ModelName: req.ModelName,
TaskID: taskID, TaskID: taskID,
State: 0, State: public.TaskStatusPending,
BizName: req.BizName, BizName: req.BizName,
CallbackURL: req.CallbackUrl, CallbackURL: req.CallbackUrl,
ModelKey: model.ApiKey, RequestPayload: &requestPayload,
InputRef: req.InputRef,
RequestPayload: storedPayload,
EpicycleId: req.EpicycleId, EpicycleId: req.EpicycleId,
}) })
if err != nil { // 入库失败:回滚闸门占位 if err != nil { // 入库失败:回滚闸门占位
@@ -97,7 +97,7 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
apiPath = r.URL.Path apiPath = r.URL.Path
httpMethod = r.Method httpMethod = r.Method
} }
_, _ = dao.OpLog.Insert(ctx, &entity.LogsModelOp{ _, _ = dao.ModelGatewayLogsOp.Insert(ctx, &entity.ModelGatewayLogsOp{
IP: ip, IP: ip,
UserAgent: ua, UserAgent: ua,
APIPath: apiPath, APIPath: apiPath,
@@ -109,20 +109,17 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
Success: 1, Success: 1,
ErrorMsg: "", ErrorMsg: "",
CostMs: time.Since(startAt).Milliseconds(), CostMs: time.Since(startAt).Milliseconds(),
RequestPayload: storedPayload, RequestPayload: &requestPayload,
ResponsePayload: gdb.Map{ ResponsePayload: gdb.Map{
"taskId": taskID, "taskId": taskID,
}, },
}) })
// 5) 获取任务信息 // 5) 获取任务信息
task, err := dao.Task.ClaimPendingByTaskIDGlobal(ctx, taskID) task, err := dao.ModelGatewayTask.ClaimByID(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if task == nil {
return nil, err
}
// 5) 创建成功后立即异步尝试执行当前任务 // 5) 创建成功后立即异步尝试执行当前任务
go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req) go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req)
@@ -130,10 +127,96 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
return &dto.CreateTaskRes{TaskID: taskID}, nil return &dto.CreateTaskRes{TaskID: taskID}, nil
} }
// GetResult returns the state of the task identified by taskID and, when a
// result file has been recorded for it, the file's OSS address.
//
// It returns the DAO error unchanged when the lookup fails, and a dedicated
// error when no task matches taskID.
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
	t, err := dao.ModelGatewayTask.Get(ctx, &entity.ModelGatewayTask{
		TaskID: taskID,
	})
	if err != nil {
		return nil, err
	}
	if t == nil {
		return nil, errors.New("任务不存在")
	}

	res = &dto.GetTaskResultRes{State: t.State}
	// ResultFile is a pointer and stays nil until a result file is attached to
	// the task, so it must be checked before reading OssFile; otherwise querying
	// a task that has not finished yet would panic.
	if t.ResultFile != nil {
		res.OssFile = t.ResultFile.OssFile
	}
	return res, nil
}
// GetBatch looks up the tasks listed in req.TaskIDs and returns, for each one
// found, its task ID, state, result-file address and text result.
//
// Side effect: every task found in the success state (2) is marked as
// downloaded through the DAO. When that update succeeds the returned item
// reports the downloaded state; when it fails the failure is logged and the
// item keeps the state that is actually stored.
//
// A nil request or an empty ID list yields an empty, non-nil list.
func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
	// taskStateSuccess is the state value the worker writes when a task
	// completes successfully (the header comment and the worker both use 2).
	const taskStateSuccess = 2

	if req == nil || len(req.TaskIDs) == 0 {
		return &dto.GetTaskBatchRes{List: []dto.GetTaskBatchItem{}}, nil
	}

	// 1) Load the requested tasks.
	list, err := dao.ModelGatewayTask.ListByTaskIDs(ctx, req.TaskIDs)
	if err != nil {
		return nil, err
	}

	items := make([]dto.GetTaskBatchItem, 0, len(list))
	for _, t := range list {
		if t == nil {
			continue
		}

		// 2) Successful tasks are flagged as downloaded. This is best effort:
		// a failed update must not prevent the caller from getting results.
		if t.State == taskStateSuccess {
			if markErr := dao.ModelGatewayTask.MarkDownloadedByID(ctx, t.Id); markErr != nil {
				g.Log().Warningf(ctx, "[批量查询] 标记已下载失败 taskId=%s err=%v", t.TaskID, markErr)
			} else {
				// Mirror the persisted change so this response matches the database.
				t.State = public.TaskStatusDownloaded
			}
		}

		// 3) Assemble the response item.
		item := dto.GetTaskBatchItem{
			TaskID:     t.TaskID,
			State:      t.State,
			TextResult: t.TextResult,
		}
		// ResultFile is a pointer and is nil for tasks without a recorded
		// result file (e.g. pending or failed ones); guard against a panic.
		if t.ResultFile != nil {
			item.OssFile = t.ResultFile.OssFile
		}
		items = append(items, item)
	}

	return &dto.GetTaskBatchRes{List: items}, nil
}
// List returns one page of tasks created by the current user, filtered by the
// model name, business name, state and task ID carried in req.
//
// A nil req is treated as an empty request. PageNum and PageSize fall back to
// 1 and 10 when they are not positive; note that the defaults are written back
// into req.
//
// It returns an error when the current user cannot be resolved from ctx or
// when the DAO query fails.
func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (*dto.ListTaskRes, error) {
	// Tolerate a missing request instead of panicking on the field reads below.
	if req == nil {
		req = &dto.ListTaskReq{}
	}
	if req.PageNum <= 0 {
		req.PageNum = 1
	}
	if req.PageSize <= 0 {
		req.PageSize = 10
	}

	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return nil, err
	}

	// The creator filter restricts the listing to the caller's own tasks.
	list, total, err := dao.ModelGatewayTask.List(ctx, req.PageNum, req.PageSize, &entity.ModelGatewayTask{
		SQLBaseDO: beans.SQLBaseDO{
			Creator: user.UserName,
		},
		ModelName: req.ModelName,
		BizName:   req.BizName,
		State:     req.State,
		TaskID:    req.TaskID,
	})
	if err != nil {
		return nil, err
	}
	return &dto.ListTaskRes{List: list, Total: total}, nil
}
// ModelTaskCallback 模型异步任务的回调通知
func (s *taskService) ModelTaskCallback(ctx context.Context, req *dto.ModelTaskCallbackReq) (*dto.ModelTaskCallbackRes, error) { func (s *taskService) ModelTaskCallback(ctx context.Context, req *dto.ModelTaskCallbackReq) (*dto.ModelTaskCallbackRes, error) {
g.Log().Infof(ctx, "[模型回调] 收到通知 taskID=%s status=%s", req.TaskID, req.Status) g.Log().Infof(ctx, "[模型回调] 收到通知 taskID=%s status=%s", req.TaskID, req.Status)
// 1. 查本地任务 // 1. 查本地任务
task, err := dao.Task.Get(ctx, &entity.AsynchTask{ task, err := dao.ModelGatewayTask.Get(ctx, &entity.ModelGatewayTask{
TaskID: req.TaskID, TaskID: req.TaskID,
}) })
if err != nil || task == nil { if err != nil || task == nil {
@@ -167,7 +250,7 @@ func (s *taskService) QueryPendingTasks(ctx context.Context, req *dto.QueryPendi
} }
// 1. 查 state=1执行中的异步任务 // 1. 查 state=1执行中的异步任务
tasks, err := dao.Task.GetPendingAsyncTasks(ctx, limit) tasks, err := dao.ModelGatewayTask.GetPendingAsyncTasks(ctx, limit)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -176,7 +259,7 @@ func (s *taskService) QueryPendingTasks(ctx context.Context, req *dto.QueryPendi
var results []dto.QueryTaskItem var results []dto.QueryTaskItem
for _, t := range tasks { for _, t := range tasks {
// 拿到模型配置 // 拿到模型配置
model, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName) model, err := dao.ModelGatewayModels.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
if err != nil || model == nil || model.QueryConfig == nil { if err != nil || model == nil || model.QueryConfig == nil {
continue continue
} }
@@ -206,100 +289,3 @@ func (s *taskService) QueryPendingTasks(ctx context.Context, req *dto.QueryPendi
Results: results, Results: results,
}, nil }, nil
} }
// GetResult fetches the task identified by taskID and returns its OSS file
// address together with its state. It returns the DAO error when the lookup
// fails and a dedicated error when no task is found.
// NOTE(review): this is the removed side of the diff hunk; it appears to be
// superseded by the ModelGatewayTask-based version — confirm before reusing.
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
TaskID: taskID,
})
if err != nil {
return nil, err
}
if t == nil {
return nil, errors.New("任务不存在")
}
return &dto.GetTaskResultRes{
OssFile: t.OssFile,
State: t.State,
}, nil
}
// GetBatch looks up the tasks listed in req.TaskIDs. Tasks in the success
// state (2) are marked as downloaded (4) with an expiry time, and every task
// found is returned with its ID, state and OSS file address.
// A nil request or an empty ID list yields an empty, non-nil list.
// NOTE(review): this is the removed side of the diff hunk; it appears to be
// superseded by the ModelGatewayTask-based version — confirm before reusing.
func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
if req == nil || len(req.TaskIDs) == 0 {
return &dto.GetTaskBatchRes{List: []dto.GetTaskBatchItem{}}, nil
}
// 1) Load the requested tasks (presumably scoped to the current tenant inside
// the DAO — not visible here, verify).
list, err := dao.Task.ListByTaskIDs(ctx, req.TaskIDs)
if err != nil {
return nil, err
}
// 2) For successful tasks (state=2): mark as downloaded (state=4) and set expire_at.
// A single timestamp is taken up front so all tasks in the batch share the same base time.
now := time.Now()
for _, t := range list {
if t == nil {
continue
}
if t.State != 2 {
continue
}
// The retention period comes from the model configuration, looked up by model name.
m, err := dao.Model.Get(ctx, &entity.AsynchModel{
ModelName: t.ModelName,
})
if err != nil {
return nil, err
}
// Default retention is 86400 seconds (one day) unless the model sets a positive AutoCleanSeconds.
retainSeconds := 86400
if m != nil && m.AutoCleanSeconds > 0 {
retainSeconds = m.AutoCleanSeconds
}
expireAt := gtime.New(now.Add(time.Duration(retainSeconds) * time.Second))
// The update error is deliberately discarded: marking is best effort.
_ = dao.Task.MarkDownloadedByID(ctx, t.Id, expireAt)
// Mirror the change in memory so this response reflects the new state.
t.State = 4
t.ExpireAt = expireAt
}
// 3) Assemble the response.
items := make([]dto.GetTaskBatchItem, 0, len(list))
for _, t := range list {
if t == nil {
continue
}
items = append(items, dto.GetTaskBatchItem{
TaskID: t.TaskID,
State: t.State,
OssFile: t.OssFile,
})
}
return &dto.GetTaskBatchRes{List: items}, nil
}
// List returns one page of tasks filtered by model name, task ID and state.
// A nil req is accepted: paging then defaults to page 1 with 10 items and no
// filters are applied. Non-positive PageNum/PageSize fall back to the same defaults.
// NOTE(review): this is the removed side of the diff hunk; it appears to be
// superseded by the ModelGatewayTask-based version — confirm before reusing.
func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
// Paging defaults, overridden only by positive request values.
pageNum, pageSize := 1, 10
if req != nil {
if req.PageNum > 0 {
pageNum = req.PageNum
}
if req.PageSize > 0 {
pageSize = req.PageSize
}
}
// Filters stay at their zero values (no filtering) when req is nil.
modelName := ""
taskID := ""
var state *int
if req != nil {
modelName = req.ModelName
taskID = req.TaskID
state = req.State
}
list, total, err := dao.Task.List(ctx, pageNum, pageSize, modelName, taskID, state)
if err != nil {
return nil, err
}
return &dto.ListTaskRes{List: list, Total: total}, nil
}

View File

@@ -21,6 +21,7 @@ import (
"model-gateway/service/gateway" "model-gateway/service/gateway"
"model-gateway/service/queue" "model-gateway/service/queue"
"gitea.redpowerfuture.com/red-future/common/beans"
"github.com/gogf/gf/v2/encoding/gjson" "github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv" "github.com/gogf/gf/v2/util/gconv"
@@ -32,11 +33,13 @@ type asyncWorker struct {
} }
// handleOne 执行一次完整的任务 // handleOne 执行一次完整的任务
func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, req *dto.CreateTaskReq) { func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq) {
body := util.GetModelBody(task.RequestPayload) // 核心请求参数 var (
maxRetry := model.RetryTimes // 重试次 body = task.RequestPayload.Body // 核心请求参
startTime := time.Now() maxRetry = model.RetryTimes // 重试次数
startTime = time.Now()
modelMessages = map[string]any{}
)
g.Log().Infof(ctx, "[执行任务][开始] taskId=%s model=%s", task.TaskID, task.ModelName) g.Log().Infof(ctx, "[执行任务][开始] taskId=%s model=%s", task.TaskID, task.ModelName)
// 1) 分布式并发控制 // 1) 分布式并发控制
@@ -49,8 +52,13 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo
return return
} }
if !acquired { if !acquired {
_, _ = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
SQLBaseDO: beans.SQLBaseDO{
Id: task.Id,
},
State: public.TaskStatusPending,
})
g.Log().Infof(ctx, "[执行任务][排队] 并发已满,放回队列 taskId=%s", task.TaskID) g.Log().Infof(ctx, "[执行任务][排队] 并发已满,放回队列 taskId=%s", task.TaskID)
_ = w.rollbackToPending(ctx, task.Id)
return return
} }
defer func() { _ = queue.ReleaseSemaphore(ctx, semKey) }() defer func() { _ = queue.ReleaseSemaphore(ctx, semKey) }()
@@ -63,24 +71,24 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo
w.failTask(ctx, task, startTime, err.Error()) w.failTask(ctx, task, startTime, err.Error())
return return
} }
body, err = util.ParseStreamResponse(rawBytes, model.StreamConfig) modelMessages, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
if err != nil { if err != nil {
w.failTask(ctx, task, startTime, err.Error()) w.failTask(ctx, task, startTime, err.Error())
return return
} }
case model.CallMode != nil && *model.CallMode == public.CallModeAsync: case model.CallMode != nil && *model.CallMode == public.CallModeAsync:
body, err = w.callModel(ctx, task, model, body) modelMessages, err = w.callModel(ctx, task, model, body)
if err != nil { if err != nil {
w.failTask(ctx, task, startTime, err.Error()) w.failTask(ctx, task, startTime, err.Error())
return return
} }
body, err = util.PullTaskResult(ctx, body, model.QueryConfig, model.HeadMsg) modelMessages, err = util.PullTaskResult(ctx, modelMessages, model.QueryConfig, model.HeadMsg)
if err != nil { if err != nil {
w.failTask(ctx, task, startTime, err.Error()) w.failTask(ctx, task, startTime, err.Error())
return return
} }
default: default:
body, err = w.callModel(ctx, task, model, body) modelMessages, err = w.callModel(ctx, task, model, body)
if err != nil { if err != nil {
w.failTask(ctx, task, startTime, err.Error()) w.failTask(ctx, task, startTime, err.Error())
return return
@@ -88,17 +96,20 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo
} }
// 3) 保存临时文件 // 3) 保存临时文件
tmpPath, err := util.SaveTempFileByType(task.TaskID, body, task.TmpFile) tmpPath, err := util.SaveTempFileByType(task.TaskID, modelMessages, task.TmpFile)
if err == nil && tmpPath != "" { if err == nil && tmpPath != "" {
task.TmpFile = tmpPath task.TmpFile = tmpPath
task.Phase = 1 task.Phase = 1
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath) _, err = dao.ModelGatewayTask.Update(ctx, task)
if err != nil {
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err)
}
} }
// 4) 解析校验 + 响应映射(可重试,失败重新调模型) // 4) 解析校验 + 响应映射(可重试,失败重新调模型)
body, err = w.parseAndRetry(ctx, body, task, model, req, maxRetry, startTime) modelMessages, err = w.parseAndRetry(ctx, modelMessages, task, model, req, maxRetry, startTime)
if err != nil { if err != nil {
task.TextResult = body task.TextResult = modelMessages
w.failTask(ctx, task, startTime, err.Error()) w.failTask(ctx, task, startTime, err.Error())
return return
} }
@@ -116,7 +127,13 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo
g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v", g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v",
task.TaskID, attempt, maxRetry, err) task.TaskID, attempt, maxRetry, err)
if attempt == maxRetry { if attempt == maxRetry {
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, task.Id, err.Error()) task.State = 3
task.ErrorMsg = err.Error()
task.Phase = 1
_, err = dao.ModelGatewayTask.Update(ctx, task)
if err != nil {
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err)
}
w.failTask(ctx, task, startTime, fmt.Sprintf("OSS上传重试耗尽: %v", err)) w.failTask(ctx, task, startTime, fmt.Sprintf("OSS上传重试耗尽: %v", err))
return return
} }
@@ -125,12 +142,13 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo
// 6) 成功回调 // 6) 成功回调
task.State = 2 task.State = 2
task.DurationSeconds = int64(time.Since(startTime).Seconds()) task.DurationSeconds = int64(time.Since(startTime).Seconds())
task.OssFile = oss.FileAddressPrefix + oss.FileURL task.ResultFile = &entity.ResultFile{
task.FileType = oss.FileFormat OssFile: oss.FileAddressPrefix + oss.FileURL,
task.TextResult = body FileType: oss.FileFormat,
task.FileSize = int64(oss.FileSize) FileSize: int64(oss.FileSize),
}
if err = dao.Task.UpdateSuccessGlobal(ctx, task); err != nil { task.TextResult = modelMessages
if _, err = dao.ModelGatewayTask.Update(ctx, task); err != nil {
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err) g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err)
return return
} }
@@ -149,7 +167,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo
} }
// callModelStream 调用模型,返回原始字节(不做响应映射,用于流式输出) // callModelStream 调用模型,返回原始字节(不做响应映射,用于流式输出)
func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) ([]byte, error) { func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) {
var data []byte var data []byte
var err error var err error
@@ -161,8 +179,7 @@ func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.AsynchTa
} }
if data == nil { if data == nil {
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName) data, err = InvokeModel(ctx, model, body)
data, err = InvokeModel(ctx, model, body, task.ModelKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -170,7 +187,10 @@ func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.AsynchTa
if tmpErr == nil && tmpPath != "" { if tmpErr == nil && tmpPath != "" {
task.TmpFile = tmpPath task.TmpFile = tmpPath
task.Phase = 1 task.Phase = 1
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath) _, err = dao.ModelGatewayTask.Update(ctx, task)
if err != nil {
g.Log().Errorf(ctx, "[执行任务][失败] 临时文件保存失败 taskId=%s err=%v", task.TaskID, tmpErr)
}
} }
} }
@@ -186,7 +206,7 @@ type asyncResult struct {
// asyncTaskChan 全局异步任务等待通道 // asyncTaskChan 全局异步任务等待通道
var asyncTaskChan = sync.Map{} // taskID → chan asyncResult var asyncTaskChan = sync.Map{} // taskID → chan asyncResult
func (w *asyncWorker) callModelAsync(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) (map[string]any, error) { func (w *asyncWorker) callModelAsync(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, body map[string]any) (map[string]any, error) {
// 1. 提交异步任务 // 1. 提交异步任务
body, err := w.callModel(ctx, task, model, body) body, err := w.callModel(ctx, task, model, body)
if err != nil { if err != nil {
@@ -231,7 +251,7 @@ func NotifyAsyncResult(taskID string, result map[string]any, err error) {
// callModel 调用模型 + 检测文件类型 + 保存临时文件 // callModel 调用模型 + 检测文件类型 + 保存临时文件
// 返回: 解析后的响应体, error // 返回: 解析后的响应体, error
func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) (map[string]any, error) { func (w *asyncWorker) callModel(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, body map[string]any) (map[string]any, error) {
var data []byte var data []byte
var err error var err error
@@ -246,8 +266,7 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, mo
// 2) 没有可用数据,调用模型 // 2) 没有可用数据,调用模型
if data == nil { if data == nil {
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName) data, err = InvokeModel(ctx, model, body)
data, err = InvokeModel(ctx, model, body, task.ModelKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -258,7 +277,10 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, mo
if tmpErr == nil && tmpPath != "" { if tmpErr == nil && tmpPath != "" {
task.TmpFile = tmpPath task.TmpFile = tmpPath
task.Phase = 1 task.Phase = 1
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath) _, err = dao.ModelGatewayTask.Update(ctx, task)
if err != nil {
g.Log().Errorf(ctx, "[执行任务][失败] 临时文件保存失败 taskId=%s err=%v", task.TaskID, tmpErr)
}
} }
} }
@@ -279,7 +301,7 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, mo
} }
// parseAndRetry 解析模型返回结果,并重试 // parseAndRetry 解析模型返回结果,并重试
func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.AsynchTask, model *entity.AsynchModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) { func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) {
for attempt := 0; attempt <= maxRetry; attempt++ { for attempt := 0; attempt <= maxRetry; attempt++ {
if attempt > 0 { if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID) g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
@@ -296,10 +318,11 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
} }
// 2) 先存 token 到数据库,防止后续失败丢失 // 2) 先存 token 到数据库,防止后续失败丢失
if tokens, ok := mapped[model.ResponseTokenField]; ok { if _, ok := mapped[model.ResponseTokenField]; ok {
task.ExpendTokens = gconv.Int64(tokens) task.ExpendTokens = gconv.Int64(mapped[model.ResponseTokenField])
_ = dao.Task.UpdateColumns(ctx, task.Id, entity.AsynchTask{ _, err = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
ExpendTokens: gconv.Int64(body[model.ResponseTokenField]), SQLBaseDO: beans.SQLBaseDO{Id: task.Id},
ExpendTokens: task.ExpendTokens,
}) })
} }
@@ -325,9 +348,10 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
} }
// 4) 重新调模型(直接调,不走缓存) // 4) 重新调模型(直接调,不走缓存)
_ = dao.Task.IncRetryCountGlobal(ctx, task.Id) task.RetryCount++
reqBody := util.GetModelBody(task.RequestPayload) _, _ = dao.ModelGatewayTask.Update(ctx, task)
rawData, callErr := InvokeModel(ctx, model, reqBody, task.ModelKey) rawData, callErr := InvokeModel(ctx, model, task.RequestPayload.Body)
if callErr != nil { if callErr != nil {
g.Log().Warningf(ctx, "[执行任务][重调模型失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, callErr) g.Log().Warningf(ctx, "[执行任务][重调模型失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, callErr)
continue continue
@@ -335,7 +359,7 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
// 5) 解析原始响应,覆盖 body 进入下一轮 // 5) 解析原始响应,覆盖 body 进入下一轮
var rawResp map[string]any var rawResp map[string]any
if err := json.Unmarshal(rawData, &rawResp); err != nil { if err = json.Unmarshal(rawData, &rawResp); err != nil {
g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err) g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err)
continue continue
} }
@@ -347,18 +371,21 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
// InvokeModel 调用模型服务,返回二进制结果 // InvokeModel 调用模型服务,返回二进制结果
// modelKey 用于覆盖/补充模型配置 head_msg例如每次请求携带不同的 X-API-Key // modelKey 用于覆盖/补充模型配置 head_msg例如每次请求携带不同的 X-API-Key
func InvokeModel(ctx context.Context, model *entity.AsynchModel, body map[string]any, modelKey string) ([]byte, error) { func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) {
// 1)请求参数映射:将标准 payload 按模型配置的 requestMapping 转为模型需要的格式 // 1) 记录模型调用次数
_ = dao.ModelGatewayLogsStat.IncRequestCount(ctx, time.Now(), model.TenantId, model.Creator, model.ModelName)
// 2请求参数映射将标准 payload 按模型配置的 requestMapping 转为模型需要的格式
//—— 请求映射实际处理为提示词构建请求,因为有附加字段及其他字段的拼接。这里不方便做请求映射 //—— 请求映射实际处理为提示词构建请求,因为有附加字段及其他字段的拼接。这里不方便做请求映射
//mappedPayload := util.ReverseMap(model.RequestMapping, payload) //mappedPayload := util.ReverseMap(model.RequestMapping, payload)
// 2)构建请求 URL 和超时 // 3)构建请求 URL 和超时
baseURL := strings.TrimRight(model.BaseURL, "/") baseURL := strings.TrimRight(model.BaseURL, "/")
timeout := time.Duration(model.TimeoutSeconds) * time.Second timeout := time.Duration(model.TimeoutSeconds) * time.Second
client := &http.Client{Timeout: timeout} client := &http.Client{Timeout: timeout}
method := strings.ToUpper(strings.TrimSpace(model.HttpMethod)) method := strings.ToUpper(strings.TrimSpace(model.HttpMethod))
// 3)构建 HTTP 请求 // 4)构建 HTTP 请求
var req *http.Request var req *http.Request
switch method { switch method {
case http.MethodGet: case http.MethodGet:
@@ -382,31 +409,31 @@ func InvokeModel(ctx context.Context, model *entity.AsynchModel, body map[string
req, err = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes)) req, err = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes))
} }
// 4)注入请求头:先模型静态配置,再动态 modelKey后者可覆盖前者 // 5)注入请求头:先模型静态配置,再动态 modelKey后者可覆盖前者
for hk, hv := range util.ParseHeadMsgHeaders(model.HeadMsg) { for hk, hv := range util.ParseHeadMsgHeaders(model.HeadMsg) {
req.Header.Set(hk, hv) req.Header.Set(hk, hv)
} }
if modelKey != "" { if model.ApiKey != "" {
req.Header.Set("Authorization", "Bearer "+modelKey) req.Header.Set("Authorization", "Bearer "+model.ApiKey)
} }
if method != http.MethodGet { if method != http.MethodGet {
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
} }
// 5)发送请求 // 6)发送请求
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer resp.Body.Close() defer resp.Body.Close()
// 6)读取响应体 // 7)读取响应体
b, err := io.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 7)检查 HTTP 状态码 // 8)检查 HTTP 状态码
if resp.StatusCode < 200 || resp.StatusCode >= 300 { if resp.StatusCode < 200 || resp.StatusCode >= 300 {
msg := string(b) msg := string(b)
return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg) return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
@@ -469,7 +496,7 @@ func InvokeModel(ctx context.Context, model *entity.AsynchModel, body map[string
// } // }
// uploadOSS 从临时文件上传 OSS // uploadOSS 从临时文件上传 OSS
func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.AsynchTask) (*gateway.UploadFileResponse, error) { func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.ModelGatewayTask) (*gateway.UploadFileResponse, error) {
data, err := os.ReadFile(t.TmpFile) data, err := os.ReadFile(t.TmpFile)
if err != nil { if err != nil {
return nil, fmt.Errorf("读取临时文件失败: %w", err) return nil, fmt.Errorf("读取临时文件失败: %w", err)
@@ -479,16 +506,14 @@ func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.AsynchTask) (*gat
} }
// failTask 任务失败统一处理:更新数据库 + 释放排队 + 回调 // failTask 任务失败统一处理:更新数据库 + 释放排队 + 回调
func (w *asyncWorker) failTask(ctx context.Context, t *entity.AsynchTask, startTime time.Time, errMsg string) { func (w *asyncWorker) failTask(ctx context.Context, t *entity.ModelGatewayTask, startTime time.Time, errMsg string) {
t.State = 3 t.State = 3
t.ErrorMsg = errMsg t.ErrorMsg = errMsg
t.DurationSeconds = int64(time.Since(startTime).Seconds()) t.DurationSeconds = int64(time.Since(startTime).Seconds())
_ = dao.Task.UpdateFailedGlobal(ctx, t) _, err := dao.ModelGatewayTask.Update(ctx, t)
if err != nil {
g.Log().Warningf(ctx, "[执行任务][更新数据库失败] taskId=%s err=%v", t.TaskID, err)
}
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
go gateway.TriggerCallback(context.WithoutCancel(ctx), t) go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
} }
// rollbackToPending 恢复任务状态为 PENDING
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
return dao.Task.RollbackToPendingGlobal(ctx, id)
}

View File

@@ -109,88 +109,61 @@ COMMENT ON COLUMN asynch_models.response_token_field IS '响应中消耗token的
-- ========================= -- =========================
-- 2) asynch_task -- model_gateway_task
-- ========================= -- =========================
CREATE TABLE IF NOT EXISTS asynch_task ( CREATE TABLE model_gateway_task (
-- 基础字段 id int8 PRIMARY KEY,
id BIGINT PRIMARY KEY, -- 主键ID(非自增) tenant_id int8 NOT NULL DEFAULT 0,
tenant_id BIGINT NOT NULL DEFAULT 0, -- 租户ID creator varchar(64) NOT NULL,
creator VARCHAR(64) NOT NULL, -- 创建人 created_at timestamp(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- 创建时间 updater varchar(64) NOT NULL,
updater VARCHAR(64) NOT NULL, -- 更新人 updated_at timestamp(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- 更新时间 deleted_at timestamp(6),
deleted_at TIMESTAMP(6), -- 删除时间(软删) model_name varchar(128) NOT NULL,
task_id varchar(64) NOT NULL,
-- 业务字段 biz_name varchar(128) NOT NULL DEFAULT '',
model_name VARCHAR(128) NOT NULL, -- 模型名称 callback_url varchar(512) DEFAULT '',
task_id VARCHAR(64) NOT NULL, -- 任务ID(对外返回) state int2 NOT NULL DEFAULT 0,
biz_name VARCHAR(128) NOT NULL DEFAULT '', -- 业务名称(调用方模块/系统) retry_count int4 NOT NULL DEFAULT 0,
callback_url VARCHAR(512) DEFAULT '', -- 回调地址(可选,用于后续业务通知) phase int2 NOT NULL DEFAULT 0,
model_key VARCHAR(1024) DEFAULT '', -- 动态请求头(用于覆盖/补充模型配置 head_msg),如 X-API-Key:xxx tmp_file text DEFAULT '',
state SMALLINT NOT NULL DEFAULT 0, -- 0排队中/1执行中/2成功/3失败/4已下载 error_msg text DEFAULT '',
oss_file VARCHAR(512) DEFAULT '', -- 结果文件OSS地址 result_file jsonb NOT NULL DEFAULT '{}',
file_type VARCHAR(32) DEFAULT '', -- 文件类型(mp3/mp4/png/...) request_payload jsonb NOT NULL DEFAULT '{}',
file_size BIGINT NOT NULL DEFAULT 0, -- 文件大小(字节) text_result jsonb NOT NULL DEFAULT '{}',
error_msg TEXT DEFAULT '', -- 错误信息 expend_tokens int8 NOT NULL DEFAULT 0,
started_at TIMESTAMP, -- 开始执行时间 duration_seconds int8 NOT NULL DEFAULT 0,
finished_at TIMESTAMP, -- 执行结束时间 epicycle_id varchar(64) NOT NULL DEFAULT ''
duration_seconds BIGINT NOT NULL DEFAULT 0, -- 耗时(秒):从创建到完成(成功/失败)整体耗时
expire_at TIMESTAMP, -- state=4 后写入,用于清理
retry_count INT NOT NULL DEFAULT 0, -- 已重试次数(不含首次)
enqueue_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- 入队时间(用于排队顺序)
phase SMALLINT NOT NULL DEFAULT 0, -- 0模型阶段/1OSS阶段
tmp_file TEXT DEFAULT '', -- 临时结果文件路径(phase=1 时仅重试 OSS 上传)
input_ref TEXT DEFAULT '', -- 输入引用(如OSS/业务资源ID等)
request_payload JSONB, -- 请求参数(可选)
text_result TEXT DEFAULT '', -- 文本类结果(可选,支持直接回调)
epicycle_id VARCHAR(64) DEFAULT '', -- 轮次ID
expend_tokens BIGINT NOT NULL DEFAULT 0 -- 消耗 token 数
); );
CREATE UNIQUE INDEX IF NOT EXISTS uk_asynch_task_tenant_task_id ON asynch_task(tenant_id, task_id); CREATE UNIQUE INDEX uk_model_gateway_task_tenant_creator_task_id ON model_gateway_task (tenant_id, creator, task_id);
CREATE INDEX IF NOT EXISTS idx_asynch_task_tenant_id ON asynch_task(tenant_id); CREATE INDEX idx_model_gateway_task_task_id ON model_gateway_task (task_id);
CREATE INDEX IF NOT EXISTS idx_asynch_task_model_name ON asynch_task(model_name); CREATE INDEX idx_model_gateway_task_state ON model_gateway_task (state);
CREATE INDEX IF NOT EXISTS idx_asynch_task_biz_name ON asynch_task(biz_name); CREATE INDEX idx_model_gateway_task_deleted_at ON model_gateway_task (deleted_at);
CREATE INDEX IF NOT EXISTS idx_asynch_task_model_key ON asynch_task(model_key);
CREATE INDEX IF NOT EXISTS idx_asynch_task_state ON asynch_task(state);
CREATE INDEX IF NOT EXISTS idx_asynch_task_enqueue_at ON asynch_task(enqueue_at);
CREATE INDEX IF NOT EXISTS idx_asynch_task_updated_at ON asynch_task(updated_at);
CREATE INDEX IF NOT EXISTS idx_asynch_task_expire_at ON asynch_task(expire_at);
CREATE INDEX IF NOT EXISTS idx_asynch_task_deleted_at ON asynch_task(deleted_at);
CREATE INDEX IF NOT EXISTS idx_asynch_task_epicycle_id ON asynch_task(epicycle_id);
CREATE INDEX IF NOT EXISTS idx_asynch_task_expend_tokens ON asynch_task(expend_tokens);
COMMENT ON TABLE asynch_task IS '异步任务表'; COMMENT ON TABLE model_gateway_task IS '模型网关任务表';
COMMENT ON COLUMN asynch_task.id IS '主键ID(非自增)'; COMMENT ON COLUMN model_gateway_task.id IS '主键ID';
COMMENT ON COLUMN asynch_task.tenant_id IS '租户ID'; COMMENT ON COLUMN model_gateway_task.tenant_id IS '租户ID';
COMMENT ON COLUMN asynch_task.creator IS '创建人'; COMMENT ON COLUMN model_gateway_task.creator IS '创建人';
COMMENT ON COLUMN asynch_task.created_at IS '创建时间'; COMMENT ON COLUMN model_gateway_task.created_at IS '创建时间';
COMMENT ON COLUMN asynch_task.updater IS '更新人'; COMMENT ON COLUMN model_gateway_task.updater IS '更新人';
COMMENT ON COLUMN asynch_task.updated_at IS '更新时间'; COMMENT ON COLUMN model_gateway_task.updated_at IS '更新时间';
COMMENT ON COLUMN asynch_task.deleted_at IS '删除时间(软删)'; COMMENT ON COLUMN model_gateway_task.deleted_at IS '删除时间软删';
COMMENT ON COLUMN asynch_task.model_name IS '模型名称'; COMMENT ON COLUMN model_gateway_task.model_name IS '模型名称';
COMMENT ON COLUMN asynch_task.task_id IS '任务ID(对外返回)'; COMMENT ON COLUMN model_gateway_task.task_id IS '任务ID对外返回';
COMMENT ON COLUMN asynch_task.biz_name IS '业务名称(调用方模块/系统)'; COMMENT ON COLUMN model_gateway_task.biz_name IS '业务名称调用方模块/系统';
COMMENT ON COLUMN asynch_task.callback_url IS '回调地址(可选,用于后续业务通知)'; COMMENT ON COLUMN model_gateway_task.callback_url IS '回调地址';
COMMENT ON COLUMN asynch_task.model_key IS '动态请求头(用于覆盖/补充模型配置 head_msg),如 X-API-Key:xxx'; COMMENT ON COLUMN model_gateway_task.state IS '0排队中/1执行中/2成功/3失败/4已下载';
COMMENT ON COLUMN asynch_task.state IS '0排队中/1执行中/2成功/3失败/4已下载'; COMMENT ON COLUMN model_gateway_task.retry_count IS '已重试次数';
COMMENT ON COLUMN asynch_task.oss_file IS '结果文件OSS地址'; COMMENT ON COLUMN model_gateway_task.phase IS '执行阶段0模型阶段/1OSS阶段';
COMMENT ON COLUMN asynch_task.file_type IS '文件类型(mp3/mp4/png/...)'; COMMENT ON COLUMN model_gateway_task.tmp_file IS '临时结果文件路径';
COMMENT ON COLUMN asynch_task.file_size IS '文件大小(字节)'; COMMENT ON COLUMN model_gateway_task.error_msg IS '错误信息';
COMMENT ON COLUMN asynch_task.error_msg IS '错误信息'; COMMENT ON COLUMN model_gateway_task.result_file IS '结果文件:{oss_file, file_type, file_size}';
COMMENT ON COLUMN asynch_task.started_at IS '开始执行时间'; COMMENT ON COLUMN model_gateway_task.request_payload IS '请求参数JSON';
COMMENT ON COLUMN asynch_task.finished_at IS '执行结束时间'; COMMENT ON COLUMN model_gateway_task.text_result IS '文本类结果';
COMMENT ON COLUMN asynch_task.duration_seconds IS '耗时(秒):从创建到完成(成功/失败)整体耗时'; COMMENT ON COLUMN model_gateway_task.expend_tokens IS '消耗token数';
COMMENT ON COLUMN asynch_task.expire_at IS 'state=4 后写入,用于清理'; COMMENT ON COLUMN model_gateway_task.duration_seconds IS '耗时(秒)';
COMMENT ON COLUMN asynch_task.retry_count IS '已重试次数(不含首次)'; COMMENT ON COLUMN model_gateway_task.epicycle_id IS '轮次ID';
COMMENT ON COLUMN asynch_task.enqueue_at IS '入队时间(用于排队顺序)';
COMMENT ON COLUMN asynch_task.phase IS '执行阶段 模型阶段/1OSS阶段(模型已成功,等待上传OSS)';
COMMENT ON COLUMN asynch_task.tmp_file IS '临时结果文件路径(phase=1 时仅重试 OSS 上传)';
COMMENT ON COLUMN asynch_task.input_ref IS '输入引用(如OSS/业务资源ID等)';
COMMENT ON COLUMN asynch_task.request_payload IS '请求参数(可选,JSON)';
COMMENT ON COLUMN asynch_task.text_result IS '文本类结果(可选,支持直接回调)';
COMMENT ON COLUMN asynch_task.epicycle_id IS '轮次ID(用于标识同一轮次的任务)';
COMMENT ON COLUMN asynch_task.expend_tokens IS '消耗 token 数';