Compare commits

4 Commits

36 changed files with 1279 additions and 1761 deletions

View File

@@ -13,7 +13,6 @@ import (
"strings" "strings"
"time" "time"
"gitea.redpowerfuture.com/red-future/common/utils"
"github.com/gogf/gf/v2/encoding/gjson" "github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv" "github.com/gogf/gf/v2/util/gconv"
@@ -21,15 +20,15 @@ import (
) )
// ParseAndValidate 解析并校验结果 // ParseAndValidate 解析并校验结果
func ParseAndValidate(raw map[string]any, model *entity.AsynchModel) (map[string]any, error) { func ParseAndValidate(raw map[string]any, model *entity.ModelGatewayModel) (map[string]any, error) {
// 1) 解析 content 字符串为 rounds 数组 // 1) 解析 content 字符串为 rounds 数组
contentVal, ok := raw[model.ResponseBody] contentVal, ok := raw[entity.ResponseBody]
if !ok { if !ok {
return raw, fmt.Errorf("字段 %s 不存在", model.ResponseBody) return raw, fmt.Errorf("字段 %s 不存在", entity.ResponseBody)
} }
contentStr, ok := contentVal.(string) contentStr, ok := contentVal.(string)
if !ok || strings.TrimSpace(contentStr) == "" { if !ok || strings.TrimSpace(contentStr) == "" {
return raw, fmt.Errorf("字段 %s 为空或不是字符串", model.ResponseBody) return raw, fmt.Errorf("字段 %s 为空或不是字符串", entity.ResponseBody)
} }
var arr []any var arr []any
if err := json.Unmarshal([]byte(contentStr), &arr); err != nil { if err := json.Unmarshal([]byte(contentStr), &arr); err != nil {
@@ -94,53 +93,6 @@ func ParseStructResult(raw map[string]any, responseBody string) map[string]any {
} }
} }
// ValidatePromptResult 校验模型返回结果的 JSON 结构完整性
// raw 必须包含 "rounds" 字段,格式为 []map[string]any
func ValidatePromptResult(raw map[string]any, model *entity.AsynchModel) error {
// 1) 获取 rounds
roundsRaw, ok := raw["rounds"]
if !ok {
return fmt.Errorf("缺少 rounds 字段")
}
rounds, ok := roundsRaw.([]any)
if !ok {
return fmt.Errorf("rounds 不是数组")
}
if len(rounds) == 0 {
return fmt.Errorf("rounds 数组为空")
}
// 2) 没有配置必填字段,跳过
if len(model.RequiredFields) == 0 {
return nil
}
// 3) 逐条校验
for i, r := range rounds {
round, ok := r.(map[string]any)
if !ok {
continue
}
for _, field := range model.RequiredFields {
if gjson.New(round).Get(field).IsNil() {
return fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, field)
}
}
}
return nil
}
// validateRequiredFields 校验单个 round 对象的必选字段
func validateRequiredFields(round map[string]any, requiredFields []string, prefix string) error {
for _, field := range requiredFields {
if gjson.New(round).Get(field).IsNil() {
return fmt.Errorf("%s 缺少必填字段: %s", prefix, field)
}
}
return nil
}
// ParseHeadMsgHeaders 从 head_msg JSON 中提取请求头 // ParseHeadMsgHeaders 从 head_msg JSON 中提取请求头
// head_msg 格式示例: // head_msg 格式示例:
// //
@@ -198,16 +150,17 @@ func MapResponsePayload(mapping map[string]any, result map[string]any) (map[stri
return mapped, nil return mapped, nil
} }
// GetModelBody 获取数据库中保存的模型信息 //
func GetModelBody(v map[string]any) map[string]any { //// GetModelBody 获取数据库中保存的模型信息
if v == nil { //func GetModelBody(v map[string]any) map[string]any {
return nil // if v == nil {
} // return nil
if p, ok := v["body"]; ok { // }
return gconv.Map(p) // if p, ok := v["body"]; ok {
} // return gconv.Map(p)
return v // }
} // return v
//}
// BodyToQuery 将 body 转为 url.Values // BodyToQuery 将 body 转为 url.Values
func BodyToQuery(payload map[string]any) (url.Values, error) { func BodyToQuery(payload map[string]any) (url.Values, error) {
@@ -348,12 +301,3 @@ func replaceURLParams(url string, params map[string]any) string {
return s return s
}) })
} }
// InjectCallbackURL 将回调地址注入到请求体中
func InjectCallbackURL(ctx context.Context, payload map[string]any, callbackURL string) map[string]any {
if callbackURL == "" {
return payload
}
payload[callbackURL] = utils.GetCallbackURL(ctx, "/task/modelCallback")
return payload
}

View File

@@ -6,6 +6,14 @@ const (
CallModeStream = 2 // 流式调用 CallModeStream = 2 // 流式调用
) )
const (
TaskStatusPending = 0 // 排队中
TaskStatusRunning = 1 // 执行中
TaskStatusSuccess = 2 // 成功
TaskStatusFailed = 3 // 失败
TaskStatusDownloaded = 4 // 已下载
)
const ( const (
BuildTypePrompt = 1 //提示词构建 BuildTypePrompt = 1 //提示词构建
BuildTypeNode = 2 //节点构建 BuildTypeNode = 2 //节点构建

View File

@@ -5,8 +5,8 @@ const (
) )
const ( const (
TableNameModel = "asynch_models" // 模型表 TableNameModel = "model_gateway_models" // 模型表
TableNameTask = "asynch_task" // 任务表 TableNameTask = "model_gateway_task" // 任务表
TableNameOpLog = "logs_model_op" // 操作日志表 TableNameOpLog = "model_gateway_logs_op" // 操作日志表
TableNameStat = "logs_model_stat" // 按天统计表(请求次数) TableNameStat = "model_gateway_logs_stat" // 按天统计表
) )

View File

@@ -7,12 +7,12 @@ import (
"model-gateway/model/dto" "model-gateway/model/dto"
) )
type stat struct{} // ModelGatewayLogsStat 统计控制器
var ModelGatewayLogsStat = new(stat)
// Stat 统计控制器 type stat struct{}
var Stat = new(stat)
// ListModelStat 统计列表 // ListModelStat 统计列表
func (c *stat) ListModelStat(ctx context.Context, req *dto.ListModelStatReq) (res *dto.ListModelStatRes, err error) { func (c *stat) ListModelStat(ctx context.Context, req *dto.ListModelStatReq) (res *dto.ListModelStatRes, err error) {
return statService.Stat.List(ctx, req) return statService.ModelGatewayLogsStat.List(ctx, req)
} }

View File

@@ -7,36 +7,36 @@ import (
"model-gateway/service/queue" "model-gateway/service/queue"
) )
type model struct{} // ModelGatewayModels 模型配置控制器
var ModelGatewayModels = new(model)
// Model 模型配置控制器 type model struct{}
var Model = new(model)
// CreateModel 添加配置 // CreateModel 添加配置
func (c *model) CreateModel(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) { func (c *model) CreateModel(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
return modelService.Model.Create(ctx, req) return modelService.ModelGatewayModels.Create(ctx, req)
} }
// UpdateModel 更改配置 // UpdateModel 更改配置
func (c *model) UpdateModel(ctx context.Context, req *dto.UpdateModelReq) (res *dto.UpdateModelRes, err error) { func (c *model) UpdateModel(ctx context.Context, req *dto.UpdateModelReq) (res *dto.UpdateModelRes, err error) {
err = modelService.Model.Update(ctx, req) err = modelService.ModelGatewayModels.Update(ctx, req)
return return
} }
// DeleteModel 删除配置 // DeleteModel 删除配置
func (c *model) DeleteModel(ctx context.Context, req *dto.DeleteModelReq) (res *dto.DeleteModelRes, err error) { func (c *model) DeleteModel(ctx context.Context, req *dto.DeleteModelReq) (res *dto.DeleteModelRes, err error) {
err = modelService.Model.Delete(ctx, req) err = modelService.ModelGatewayModels.Delete(ctx, req)
return return
} }
// GetModel 获取配置详情 // GetModel 获取配置详情
func (c *model) GetModel(ctx context.Context, req *dto.GetModelReq) (res *dto.GetModelRes, err error) { func (c *model) GetModel(ctx context.Context, req *dto.GetModelReq) (res *dto.GetModelRes, err error) {
return modelService.Model.Get(ctx, req) return modelService.ModelGatewayModels.Get(ctx, req)
} }
// ListModel 配置列表 // ListModel 配置列表
func (c *model) ListModel(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) { func (c *model) ListModel(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) {
return modelService.Model.List(ctx, req) return modelService.ModelGatewayModels.List(ctx, req)
} }
// AutoTune 动态调参(由上层定时任务每小时触发一次) // AutoTune 动态调参(由上层定时任务每小时触发一次)
@@ -56,11 +56,11 @@ func (c *model) ListOperator(ctx context.Context, req *dto.ListOperatorReq) (res
// UpdateChatModel 更新是否为聊天模型 // UpdateChatModel 更新是否为聊天模型
func (c *model) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) (res *dto.UpdateChatModelRes, err error) { func (c *model) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) (res *dto.UpdateChatModelRes, err error) {
err = modelService.Model.UpdateChatModel(ctx, req) err = modelService.ModelGatewayModels.UpdateChatModel(ctx, req)
return return
} }
// GetIsChatModel 获取当前会话模型 // GetIsChatModel 获取当前会话模型
func (c *model) GetIsChatModel(ctx context.Context, req *dto.GetIsChatModelReq) (res *dto.GetIsChatModelRes, err error) { func (c *model) GetIsChatModel(ctx context.Context, req *dto.GetIsChatModelReq) (res *dto.GetIsChatModelRes, err error) {
return modelService.Model.GetIsChatModel(ctx) return modelService.ModelGatewayModels.GetIsChatModel(ctx)
} }

View File

@@ -2,48 +2,42 @@ package controller
import ( import (
"context" "context"
"model-gateway/service/job"
taskService "model-gateway/service/task" taskService "model-gateway/service/task"
"model-gateway/model/dto" "model-gateway/model/dto"
) )
type task struct{} // ModelGatewayTask 任务控制器
var ModelGatewayTask = new(task)
// Task 任务控制器 type task struct{}
var Task = new(task)
// CreateTask 根据 modelName 创建异步任务,返回 taskId // CreateTask 根据 modelName 创建异步任务,返回 taskId
func (c *task) CreateTask(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) { func (c *task) CreateTask(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
return taskService.Task.Create(ctx, req) return taskService.ModelGatewayTask.Create(ctx, req)
} }
// ModelTaskCallback 接收模型异步任务的回调通知 // GetTaskResult 获取单条任务结果(返回 *dto.GetTaskResultRes
func (c *task) ModelTaskCallback(ctx context.Context, req *dto.ModelTaskCallbackReq) (res *dto.ModelTaskCallbackRes, err error) {
return taskService.Task.ModelTaskCallback(ctx, req)
}
// QueryPendingTasks 批量轮询进行中的异步任务
func (c *task) QueryPendingTasks(ctx context.Context, req *dto.QueryPendingTasksReq) (res *dto.QueryPendingTasksRes, err error) {
return taskService.Task.QueryPendingTasks(ctx, req)
}
// GetTaskResult 获取任务结果(只返回 oss 地址 + state
func (c *task) GetTaskResult(ctx context.Context, req *dto.GetTaskResultReq) (res *dto.GetTaskResultRes, err error) { func (c *task) GetTaskResult(ctx context.Context, req *dto.GetTaskResultReq) (res *dto.GetTaskResultRes, err error) {
return taskService.Task.GetResult(ctx, req.TaskID) return taskService.ModelGatewayTask.GetResult(ctx, req.TaskID)
} }
// GetTaskBatch 批量查询任务(成功任务标记为已下载 // GetTaskBatch 批量查询任务(返回 *[]dto.GetTaskBatchItem
func (c *task) GetTaskBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) { func (c *task) GetTaskBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
return taskService.Task.GetBatch(ctx, req) return taskService.ModelGatewayTask.GetBatch(ctx, req)
} }
// ListTask 任务列表分页查询 // ListTask 任务列表分页查询
func (c *task) ListTask(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) { func (c *task) ListTask(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
return taskService.Task.List(ctx, req) return taskService.ModelGatewayTask.List(ctx, req)
} }
// CleanWork 手动触发一次 cleaner由上层定时任务调用 // ModelTaskCallback 接收模型异步任务的回调通知 —— 待调整
func (c *task) CleanWork(ctx context.Context, req *dto.CleanWorkReq) (res *dto.CleanWorkRes, err error) { func (c *task) ModelTaskCallback(ctx context.Context, req *dto.ModelTaskCallbackReq) (res *dto.ModelTaskCallbackRes, err error) {
return job.Cleaner.RunOnce(ctx) return taskService.ModelGatewayTask.ModelTaskCallback(ctx, req)
}
// QueryPendingTasks 批量轮询进行中的异步任务 —— 待调整
func (c *task) QueryPendingTasks(ctx context.Context, req *dto.QueryPendingTasksReq) (res *dto.QueryPendingTasksRes, err error) {
return taskService.ModelGatewayTask.QueryPendingTasks(ctx, req)
} }

View File

@@ -0,0 +1,23 @@
package dao
import (
"context"
"model-gateway/consts/public"
"model-gateway/model/entity"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
)
var ModelGatewayLogsOp = &modelGatewayLogsOpDao{}
type modelGatewayLogsOpDao struct{}
// Insert 插入操作日志
func (d *modelGatewayLogsOpDao) Insert(ctx context.Context, req *entity.ModelGatewayLogsOp) (int64, error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameOpLog).Insert(req)
if err != nil {
return 0, err
}
return r.LastInsertId()
}

View File

@@ -0,0 +1,52 @@
package dao
import (
"context"
"time"
"model-gateway/consts/public"
"model-gateway/model/entity"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/util/gconv"
)
var ModelGatewayLogsStat = &modelGatewayLogsStatDao{}
type modelGatewayLogsStatDao struct{}
// IncRequestCount 原子累加:按天+租户+创建人+模型 +1
func (d *modelGatewayLogsStatDao) IncRequestCount(ctx context.Context, day time.Time, tenantId uint64, creator, modelName string) error {
_, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameStat).
Data(&entity.ModelGatewayLogsStat{
Day: gtime.New(day),
TenantId: tenantId,
Creator: creator,
ModelName: modelName,
RequestCount: 1,
}).
OnDuplicate("request_count", "request_count+1").
Insert()
return err
}
// List 分页查询统计
func (d *modelGatewayLogsStatDao) List(ctx context.Context, pageNum, pageSize int, req *entity.ModelGatewayLogsStat) (list []*entity.ModelGatewayLogsStat, total int64, err error) {
model := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameStat).
OmitEmpty().
Where(entity.ModelGatewayLogsStatCols.Creator, req.Creator).
WhereLike(entity.ModelGatewayLogsStatCols.ModelName, "%"+req.ModelName+"%").
OrderDesc(entity.ModelGatewayLogsStatCols.Day).
OrderDesc(entity.ModelGatewayLogsStatCols.RequestCount)
if pageNum > 0 && pageSize > 0 {
model = model.Page(pageNum, pageSize)
}
r, totalInt, err := model.AllAndCount(false)
if err != nil {
return nil, 0, err
}
total = gconv.Int64(totalInt)
err = r.Structs(&list)
return
}

View File

@@ -9,71 +9,64 @@ import (
"gitea.redpowerfuture.com/red-future/common/db/gfdb" "gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
) )
var Model = &modelDao{} var ModelGatewayModels = &modelGatewayModelsDao{}
type modelDao struct{} type modelGatewayModelsDao struct{}
// Insert 插入 // Insert 插入
func (d *modelDao) Insert(ctx context.Context, req *entity.AsynchModel) (id int64, err error) { func (d *modelGatewayModelsDao) Insert(ctx context.Context, req *entity.ModelGatewayModel) (int64, error) {
m := new(entity.AsynchModel) r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).Insert(req)
err = gconv.Struct(req, &m)
if err != nil { if err != nil {
return return 0, err
}
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
Insert(m)
if err != nil {
return
} }
return r.LastInsertId() return r.LastInsertId()
} }
// Update 更新 // Update 更新
func (d *modelDao) Update(ctx context.Context, req *entity.AsynchModel) (rows int64, err error) { func (d *modelGatewayModelsDao) Update(ctx context.Context, req *entity.ModelGatewayModel) (int64, error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel). r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
OmitEmpty(). OmitEmpty().
Data(&req). Data(req).
Where(entity.AsynchModelCol.Id, req.Id). Where(entity.ModelGatewayModelCol.Id, req.Id).
Update() Update()
if err != nil { if err != nil {
return return 0, err
} }
return r.RowsAffected() return r.RowsAffected()
} }
// Delete 删除 // Delete 删除
func (d *modelDao) Delete(ctx context.Context, req *entity.AsynchModel) (rows int64, err error) { func (d *modelGatewayModelsDao) Delete(ctx context.Context, req *entity.ModelGatewayModel) (int64, error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel). r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
OmitEmpty(). OmitEmpty().
Where(entity.AsynchModelCol.Id, req.Id). Where(entity.ModelGatewayModelCol.Id, req.Id).
Delete() Delete()
if err != nil { if err != nil {
return return 0, err
} }
return r.RowsAffected() return r.RowsAffected()
} }
// Get 获取模型 // Get 获取模型
func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) { func (d *modelGatewayModelsDao) Get(ctx context.Context, req *entity.ModelGatewayModel, fields ...string) (*entity.ModelGatewayModel, error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel). r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
OmitEmpty(). OmitEmpty().
Where(entity.AsynchModelCol.Id, req.Id). Where(entity.ModelGatewayModelCol.Id, req.Id).
Where(entity.AsynchModelCol.Creator, req.Creator). Where(entity.ModelGatewayModelCol.Creator, req.Creator).
Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel). Where(entity.ModelGatewayModelCol.ModelName, req.ModelName).
Where(entity.AsynchModelCol.ModelName, req.ModelName).
Fields(fields).One() Fields(fields).One()
if err != nil { if err != nil {
return return nil, err
} }
var m entity.ModelGatewayModel
err = r.Struct(&m) err = r.Struct(&m)
return return &m, err
} }
//// Get 按ID获取带租户隔离只查当前租户 //// Get 按ID获取带租户隔离只查当前租户
//func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) { //func (d *modelGatewayModelsDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
// var whereCondition strings.Builder // var whereCondition strings.Builder
// var queryParams []interface{} // var queryParams []interface{}
// if !g.IsEmpty(req.Id) { // if !g.IsEmpty(req.Id) {
@@ -108,25 +101,25 @@ func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...s
// return // return
//} //}
// GetByAcrossTenant 按ID获取跨租户查所有租户 // GetByAcrossTenant 跨租户查询
func (d *modelDao) GetByAcrossTenant(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) { func (d *modelGatewayModelsDao) GetByAcrossTenant(ctx context.Context, req *entity.ModelGatewayModel, fields ...string) (*entity.ModelGatewayModel, error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel). r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
NoTenantId(ctx). NoTenantId(ctx).
OmitEmpty(). OmitEmpty().
Where(entity.AsynchModelCol.Id, req.Id). Where(entity.ModelGatewayModelCol.Id, req.Id).
Where(entity.AsynchModelCol.Creator, req.Creator). Where(entity.ModelGatewayModelCol.Creator, req.Creator).
Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel). Where(entity.ModelGatewayModelCol.ModelName, req.ModelName).
Where(entity.AsynchModelCol.ModelName, req.ModelName).
Fields(fields).One() Fields(fields).One()
if err != nil { if err != nil {
return return nil, err
} }
var m entity.ModelGatewayModel
err = r.Struct(&m) err = r.Struct(&m)
return return &m, err
} }
// GetByCreatorAndPlatform 按创建者、平台获取 // GetByCreatorAndPlatform 按创建者、平台获取
func (d *modelDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.AsynchModel, total int, err error) { func (d *modelGatewayModelsDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.ModelGatewayModel, total int, err error) {
sql := ` sql := `
SELECT DISTINCT ON (model_name) * SELECT DISTINCT ON (model_name) *
FROM asynch_models FROM asynch_models
@@ -186,7 +179,7 @@ WHERE deleted_at IS NULL
} }
// GetByModelNameForTenant 后台任务使用:按 tenant_id + model_name 查询,不依赖 gfdb Hook/Trace/用户上下文 // GetByModelNameForTenant 后台任务使用:按 tenant_id + model_name 查询,不依赖 gfdb Hook/Trace/用户上下文
func (d *modelDao) GetByModelNameForTenant(ctx context.Context, tenantId uint64, modelName string) (m *entity.AsynchModel, err error) { func (d *modelGatewayModelsDao) GetByModelNameForTenant(ctx context.Context, tenantId uint64, modelName string) (*entity.ModelGatewayModel, error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx,
"SELECT * FROM "+public.TableNameModel+" WHERE tenant_id=? AND model_name=? AND deleted_at IS NULL LIMIT 1", "SELECT * FROM "+public.TableNameModel+" WHERE tenant_id=? AND model_name=? AND deleted_at IS NULL LIMIT 1",
tenantId, modelName, tenantId, modelName,
@@ -197,7 +190,7 @@ func (d *modelDao) GetByModelNameForTenant(ctx context.Context, tenantId uint64,
if r.IsEmpty() { if r.IsEmpty() {
return nil, nil return nil, nil
} }
var list []*entity.AsynchModel var list []*entity.ModelGatewayModel
if err := r.Structs(&list); err != nil { if err := r.Structs(&list); err != nil {
return nil, err return nil, err
} }

View File

@@ -0,0 +1,159 @@
package dao
import (
"context"
"fmt"
"model-gateway/consts/public"
"model-gateway/model/entity"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/util/gconv"
)
var ModelGatewayTask = &modelGatewayTaskDao{}
type modelGatewayTaskDao struct{}
// Insert 插入
func (d *modelGatewayTaskDao) Insert(ctx context.Context, req *entity.ModelGatewayTask) (id int64, err error) {
m := new(entity.ModelGatewayTask)
err = gconv.Struct(req, &m)
if err != nil {
return
}
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).Insert(m)
if err != nil {
return
}
return r.LastInsertId()
}
// Update 更新按ID
func (d *modelGatewayTaskDao) Update(ctx context.Context, req *entity.ModelGatewayTask) (rows int64, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
OmitEmpty().
Data(req).
Where(entity.ModelGatewayTaskCol.Id, req.Id).
Update()
if err != nil {
return
}
return r.RowsAffected()
}
// Get 获取按TaskID 或 ID
func (d *modelGatewayTaskDao) Get(ctx context.Context, req *entity.ModelGatewayTask) (m *entity.ModelGatewayTask, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
OmitEmpty().
Where(entity.ModelGatewayTaskCol.TaskID, req.TaskID).
Where(entity.ModelGatewayTaskCol.Id, req.Id).
One()
if err != nil {
return
}
err = r.Struct(&m)
return
}
// List 分页查询
func (d *modelGatewayTaskDao) List(ctx context.Context, pageNum, pageSize int, req *entity.ModelGatewayTask) (list []*entity.ModelGatewayTask, total int64, err error) {
model := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
OmitEmpty().
Where(entity.ModelGatewayTaskCol.Creator, req.Creator).
Where(entity.ModelGatewayTaskCol.ModelName, "%"+req.ModelName+"%").
Where(entity.ModelGatewayTaskCol.BizName, req.BizName).
Where(entity.ModelGatewayTaskCol.State, req.State).
Where(entity.ModelGatewayTaskCol.TaskID, req.TaskID).
OrderDesc(entity.ModelGatewayTaskCol.CreatedAt)
if pageNum > 0 && pageSize > 0 {
model = model.Page(pageNum, pageSize)
}
r, totalInt, err := model.AllAndCount(false)
if err != nil {
return nil, 0, err
}
total = gconv.Int64(totalInt)
err = r.Structs(&list)
return
}
// Delete 删除软删按ID
func (d *modelGatewayTaskDao) Delete(ctx context.Context, req *entity.ModelGatewayTask) (rows int64, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
Where(entity.ModelGatewayTaskCol.Id, req.Id).
Delete()
if err != nil {
return
}
return r.RowsAffected()
}
// ListByTaskIDs 批量查询
func (d *modelGatewayTaskDao) ListByTaskIDs(ctx context.Context, taskIDs []string) (list []*entity.ModelGatewayTask, err error) {
if len(taskIDs) == 0 {
return nil, nil
}
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
WhereIn(entity.ModelGatewayTaskCol.TaskID, taskIDs).
All()
if err != nil {
return nil, err
}
err = r.Structs(&list)
return
}
// MarkDownloadedByID 标记已下载
func (d *modelGatewayTaskDao) MarkDownloadedByID(ctx context.Context, id int64) error {
_, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
Where(entity.ModelGatewayTaskCol.Id, id).
Where(entity.ModelGatewayTaskCol.State, 2).
Data(map[string]any{entity.ModelGatewayTaskCol.State: 4}).
Update()
return err
}
// GetPendingAsyncTasks 获取进行中的异步任务
func (d *modelGatewayTaskDao) GetPendingAsyncTasks(ctx context.Context, limit int) ([]*entity.ModelGatewayTask, error) {
var tasks []*entity.ModelGatewayTask
err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
Where(entity.ModelGatewayTaskCol.State, 1).
Limit(limit).
Scan(&tasks)
return tasks, err
}
// ======================== 事务抢占 ========================
// ClaimByID 按主键抢占,返回抢占后的任务
func (d *modelGatewayTaskDao) ClaimByID(ctx context.Context, id int64) (*entity.ModelGatewayTask, error) {
var task entity.ModelGatewayTask
err := gfdb.DB(ctx, public.DbNameModelGateway).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
r, err := tx.Model(public.TableNameTask).
Where(entity.ModelGatewayTaskCol.Id, id).
Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending).
Limit(1).
LockUpdate().
One()
if err != nil {
return err
}
if r.IsEmpty() {
return fmt.Errorf("任务已被抢占或不存在: id=%d", id)
}
if err := r.Struct(&task); err != nil {
return err
}
_, err = tx.Model(public.TableNameTask).
Data(&entity.ModelGatewayTask{State: public.TaskStatusRunning}).
Where(entity.ModelGatewayTaskCol.Id, id).
OmitEmpty().
Update()
return err
})
if err != nil {
return nil, err
}
return &task, nil
}

View File

@@ -1,30 +0,0 @@
package dao
import (
"context"
"model-gateway/consts/public"
"model-gateway/model/entity"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/util/gconv"
)
type opLogDao struct{}
var OpLog = &opLogDao{}
// Insert 插入
func (d *opLogDao) Insert(ctx context.Context, req *entity.LogsModelOp) (id int64, err error) {
m := new(entity.LogsModelOp)
err = gconv.Struct(req, &m)
if err != nil {
return
}
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameOpLog).
Insert(m)
if err != nil {
return 0, err
}
return r.LastInsertId()
}

View File

@@ -1,60 +0,0 @@
package dao
import (
"context"
"fmt"
"time"
"model-gateway/consts/public"
"model-gateway/model/entity"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/os/gtime"
)
type statDao struct{}
var Stat = &statDao{}
// IncRequestCount 原子累加(支持分布式/多协程):按天+租户+创建人+模型 +1
func (d *statDao) IncRequestCount(ctx context.Context, day time.Time, tenantId int64, creator, modelName string) error {
sql := fmt.Sprintf(`
INSERT INTO %s(day, tenant_id, creator, model_name, request_count, created_at, updated_at)
VALUES(?, ?, ?, ?, 1, NOW(), NOW())
ON CONFLICT (day, tenant_id, creator, model_name)
DO UPDATE SET request_count = %s.request_count + 1, updated_at = NOW()`,
public.TableNameStat, public.TableNameStat,
)
_, err := gfdb.DB(ctx, public.DbNameModelGateway).Exec(ctx, sql, gtime.New(day).Format("Y-m-d"), tenantId, creator, modelName)
return err
}
func (d *statDao) List(ctx context.Context, pageNum, pageSize int, startDay, endDay string, tenantId *int64, creator, modelName string) (list []*entity.LogsModelStat, total int64, err error) {
m := gfdb.DB(ctx).Model(ctx, public.TableNameStat).Where("1=1")
if startDay != "" {
m = m.Where("day >= ?", startDay)
}
if endDay != "" {
m = m.Where("day <= ?", endDay)
}
if tenantId != nil {
m = m.Where("tenant_id = ?", *tenantId)
}
if creator != "" {
m = m.WhereLike("creator", "%"+creator+"%")
}
if modelName != "" {
m = m.WhereLike("model_name", "%"+modelName+"%")
}
m = m.OrderDesc("day").OrderDesc("request_count")
if pageNum > 0 && pageSize > 0 {
m = m.Page(pageNum, pageSize)
}
r, totalInt, err := m.AllAndCount(false)
if err != nil {
return nil, 0, err
}
total = int64(totalInt)
err = r.Structs(&list)
return
}

View File

@@ -1,124 +0,0 @@
package dao
import (
"context"
"model-gateway/consts/public"
"model-gateway/model/entity"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/util/gconv"
)
var Task = &taskDao{}
type taskDao struct{}
// Insert 插入
func (d *taskDao) Insert(ctx context.Context, req *entity.AsynchTask) (id int64, err error) {
m := new(entity.AsynchTask)
err = gconv.Struct(req, &m)
if err != nil {
return
}
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
Insert(m)
if err != nil {
return
}
return r.LastInsertId()
}
// Update 更新
func (d *taskDao) Update(ctx context.Context, req *entity.AsynchTask) (rows int64, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
OmitEmpty().
Data(&req).
Where(entity.AsynchTaskCol.Id, req.Id).
Update()
if err != nil {
return
}
return r.RowsAffected()
}
// Get 获取
func (d *taskDao) Get(ctx context.Context, req *entity.AsynchTask, fields ...string) (m *entity.AsynchTask, err error) {
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
OmitEmpty().
Where(entity.AsynchTaskCol.TaskID, req.TaskID).
Fields(fields).One()
if err != nil {
return
}
err = r.Struct(&m)
return
}
// ListByTaskIDs 批量查询任务(会受 gfdb 的租户 Hook 影响,只返回当前租户数据)
func (d *taskDao) ListByTaskIDs(ctx context.Context, taskIDs []string) (m []*entity.AsynchTask, err error) {
if len(taskIDs) == 0 {
return nil, nil
}
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
OmitEmpty().
WhereIn(entity.AsynchTaskCol.TaskID, taskIDs).
All()
if err != nil {
return nil, err
}
err = r.Structs(&m)
return
}
// MarkDownloadedByID 将成功任务标记为已下载(state=4),并写入过期时间
func (d *taskDao) MarkDownloadedByID(ctx context.Context, id int64, expireAt *gtime.Time) error {
data := gdb.Map{
entity.AsynchTaskCol.State: 4,
entity.AsynchTaskCol.ExpireAt: expireAt,
entity.AsynchTaskCol.Updater: "",
}
_, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
Where(entity.AsynchTaskCol.Id, id).
Where(entity.AsynchTaskCol.State, 2).
Data(data).
Update()
return err
}
// List 任务分页查询(受 gfdb 租户 Hook 影响)
func (d *taskDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike, taskIDLike string, state *int) (list []*entity.AsynchTask, total int64, err error) {
m := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).Where("deleted_at IS NULL")
if modelNameLike != "" {
m = m.WhereLike(entity.AsynchTaskCol.ModelName, "%"+modelNameLike+"%")
}
if taskIDLike != "" {
m = m.WhereLike(entity.AsynchTaskCol.TaskID, "%"+taskIDLike+"%")
}
if state != nil {
m = m.Where(entity.AsynchTaskCol.State, *state)
}
m = m.OrderDesc(entity.AsynchTaskCol.CreatedAt)
if pageNum > 0 && pageSize > 0 {
m = m.Page(pageNum, pageSize)
}
r, totalInt, err := m.AllAndCount(false)
if err != nil {
return nil, 0, err
}
total = gconv.Int64(totalInt)
err = r.Structs(&list)
return
}
// GetPendingAsyncTasks 获取进行中的异步任务
func (d *taskDao) GetPendingAsyncTasks(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
var tasks []*entity.AsynchTask
err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
Where("state", 1).
Where("deleted_at IS NULL").
Limit(limit).
Scan(&tasks)
return tasks, err
}

View File

@@ -1,214 +0,0 @@
package dao
import (
"context"
"fmt"
"time"
"model-gateway/consts/public"
"model-gateway/model/entity"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/os/gtime"
)
// ======================== 查询辅助 ========================
// taskColumns 查询用的公共字段
const taskColumns = `id, tenant_id, creator, model_name, task_id, biz_name, callback_url, model_key, retry_count, input_ref, request_payload, phase, tmp_file`
// ======================== 事务抢占 ========================
// claimTasks 事务内抢占任务并更新 state=1
func claimTasks(ctx context.Context, where string, args ...any) ([]*entity.AsynchTask, error) {
var tasks []*entity.AsynchTask
err := gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
sql := fmt.Sprintf(`SELECT %s FROM %s WHERE deleted_at IS NULL AND state = 0 %s LIMIT 1 FOR UPDATE SKIP LOCKED`, taskColumns, public.TableNameTask, where)
r, err := tx.GetOne(sql, args...)
if err != nil {
return err
}
if r.IsEmpty() {
return nil
}
var task entity.AsynchTask
if err := r.Struct(&task); err != nil {
return err
}
now := time.Now()
_, err = tx.Exec(fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask), now, now, task.Id)
if err != nil {
return err
}
tasks = []*entity.AsynchTask{&task}
return nil
})
return tasks, err
}
// ClaimPendingGlobal 批量抢占 pending 任务
func (d *taskDao) ClaimPendingGlobal(ctx context.Context, batchSize int) ([]*entity.AsynchTask, error) {
if batchSize <= 0 {
batchSize = 1
}
var tasks []*entity.AsynchTask
err := gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
sql := fmt.Sprintf(`SELECT %s FROM %s WHERE deleted_at IS NULL AND state = 0 ORDER BY enqueue_at ASC LIMIT %d FOR UPDATE SKIP LOCKED`, taskColumns, public.TableNameTask, batchSize)
r, err := tx.GetAll(sql)
if err != nil {
return err
}
if r.IsEmpty() {
return nil
}
if err := r.Structs(&tasks); err != nil {
return err
}
now := time.Now()
for _, t := range tasks {
_, err = tx.Exec(fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask), now, now, t.Id)
if err != nil {
return err
}
}
return nil
})
return tasks, err
}
// ClaimPendingByTaskIDGlobal 按 task_id 抢占
func (d *taskDao) ClaimPendingByTaskIDGlobal(ctx context.Context, taskID string) (*entity.AsynchTask, error) {
if taskID == "" {
return nil, nil
}
tasks, err := claimTasks(ctx, "AND task_id = ?", taskID)
if err != nil || len(tasks) == 0 {
return nil, err
}
return tasks[0], nil
}
// ======================== 更新辅助 ========================
func execSQL(ctx context.Context, sql string, args ...any) error {
_, err := gfdb.DB(ctx).Exec(ctx, sql, args...)
return err
}
// updateTask 通用更新
func updateTask(ctx context.Context, id int64, data entity.AsynchTask) error {
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).OmitEmpty().
Where(entity.AsynchTaskCol.Id, id).Data(data).Update()
return err
}
// UpdateSuccessGlobal 更新任务成功
func (d *taskDao) UpdateSuccessGlobal(ctx context.Context, t *entity.AsynchTask) error {
return updateTask(ctx, t.Id, entity.AsynchTask{
State: 2,
OssFile: t.OssFile,
FileType: t.FileType,
TextResult: t.TextResult,
FileSize: t.FileSize,
ErrorMsg: "",
FinishedAt: gtime.Now(),
Phase: 0,
TmpFile: "",
ExpendTokens: t.ExpendTokens,
DurationSeconds: t.DurationSeconds,
})
}
// UpdateFailedGlobal 模型调用失败
func (d *taskDao) UpdateFailedGlobal(ctx context.Context, t *entity.AsynchTask) error {
return updateTask(ctx, t.Id, entity.AsynchTask{
State: 3,
ErrorMsg: t.ErrorMsg,
FinishedAt: gtime.Now(),
Phase: 0,
TmpFile: "",
TextResult: t.TextResult,
DurationSeconds: t.DurationSeconds,
})
}
// UpdateFailedKeepTmpGlobal OSS 上传失败
func (d *taskDao) UpdateFailedKeepTmpGlobal(ctx context.Context, id int64, errorMsg string) error {
return execSQL(ctx, fmt.Sprintf(`UPDATE %s SET state=3, error_msg=?, finished_at=?, phase=1, updated_at=? WHERE id=?`, public.TableNameTask), errorMsg, gtime.Now(), gtime.Now(), id)
}
// UpdateTmpAfterModelGlobal 写临时文件
func (d *taskDao) UpdateTmpAfterModelGlobal(ctx context.Context, id int64, tmpFile string) error {
return execSQL(ctx, fmt.Sprintf(`UPDATE %s SET phase=1, tmp_file=?, updated_at=NOW() WHERE id=?`, public.TableNameTask), tmpFile, id)
}
// RollbackToPendingGlobal 回滚
func (d *taskDao) RollbackToPendingGlobal(ctx context.Context, id int64) error {
return execSQL(ctx, fmt.Sprintf(`UPDATE %s SET state=0, enqueue_at=NOW(), updated_at=NOW() WHERE id=? AND state=1`, public.TableNameTask), id)
}
// IncRetryCountGlobal 重试计数+1
func (d *taskDao) IncRetryCountGlobal(ctx context.Context, id int64) error {
return execSQL(ctx, fmt.Sprintf(`UPDATE %s SET retry_count=retry_count+1, updated_at=NOW() WHERE id=?`, public.TableNameTask), id)
}
// RequeueForRetryGlobal 重新入队
func (d *taskDao) RequeueForRetryGlobal(ctx context.Context, id int64, enqueueAt time.Time) error {
return execSQL(ctx, fmt.Sprintf(`UPDATE %s SET state=0, retry_count=retry_count+1, enqueue_at=?, updated_at=NOW() WHERE id=? AND state=3 AND deleted_at IS NULL`, public.TableNameTask), enqueueAt, id)
}
// ======================== 列表查询 ========================
// ListExpiredDownloadedGlobal
func (d *taskDao) ListExpiredDownloadedGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
return queryTasks(ctx, fmt.Sprintf(`SELECT * FROM %s WHERE deleted_at IS NULL AND state=4 AND expire_at IS NOT NULL AND expire_at < ? LIMIT ?`, public.TableNameTask), gtime.Now(), clampLimit(limit, 200))
}
// ListFailedRetryableGlobal
func (d *taskDao) ListFailedRetryableGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
return queryTasks(ctx, fmt.Sprintf(`SELECT t.*, m.retry_queue_max_seconds FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state=3 AND t.retry_count < m.retry_times ORDER BY t.updated_at ASC LIMIT ?`, public.TableNameTask, public.TableNameModel), clampLimit(limit, 200))
}
// ListFailedExhaustedGlobal
func (d *taskDao) ListFailedExhaustedGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
return queryTasks(ctx, fmt.Sprintf(`SELECT t.* FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state=3 AND t.retry_count >= m.retry_times ORDER BY t.updated_at ASC LIMIT ?`, public.TableNameTask, public.TableNameModel), clampLimit(limit, 200))
}
// ListTimeoutTasksGlobal
func (d *taskDao) ListTimeoutTasksGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
return queryTasks(ctx, fmt.Sprintf(`SELECT t.* FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state IN (0,1) AND m.expected_seconds > 0 AND t.created_at < (NOW() - (m.expected_seconds || ' seconds')::interval) LIMIT ?`, public.TableNameTask, public.TableNameModel), clampLimit(limit, 200))
}
// HardDeleteByIDGlobal
func (d *taskDao) HardDeleteByIDGlobal(ctx context.Context, id int64) error {
return execSQL(ctx, fmt.Sprintf(`DELETE FROM %s WHERE id=?`, public.TableNameTask), id)
}
// ======================== 内部辅助 ========================
func queryTasks(ctx context.Context, sql string, args ...any) ([]*entity.AsynchTask, error) {
r, err := gfdb.DB(ctx).GetAll(ctx, sql, args...)
if err != nil {
return nil, err
}
var list []*entity.AsynchTask
err = r.Structs(&list)
return list, err
}
func clampLimit(limit, defaultVal int) int {
if limit <= 0 {
return defaultVal
}
return limit
}
// UpdateColumns 更新指定字段(结构体版)
func (d *taskDao) UpdateColumns(ctx context.Context, id int64, data entity.AsynchTask) error {
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).OmitEmpty().
Where(entity.AsynchTaskCol.Id, id).
Data(data).
Update()
return err
}

29
main.go
View File

@@ -3,7 +3,6 @@ package main
import ( import (
"context" "context"
"model-gateway/model/dto" "model-gateway/model/dto"
"model-gateway/service/job"
"model-gateway/service/task" "model-gateway/service/task"
"os" "os"
"os/signal" "os/signal"
@@ -27,9 +26,9 @@ func main() {
// 注册路由 // 注册路由
http.RouteRegister([]interface{}{ http.RouteRegister([]interface{}{
controller.Model, controller.ModelGatewayModels,
controller.Task, controller.ModelGatewayTask,
controller.Stat, controller.ModelGatewayLogsStat,
}) })
// 本地调试:可选自动触发 worker/cleaner由配置文件控制 // 本地调试:可选自动触发 worker/cleaner由配置文件控制
@@ -47,26 +46,6 @@ func main() {
} }
func startAutoRunner(ctx context.Context) { func startAutoRunner(ctx context.Context) {
// cleaner
if g.Cfg().MustGet(ctx, "asynch.cleaner.enabled").Bool() {
interval := g.Cfg().MustGet(ctx, "asynch.cleaner.intervalSeconds").Int()
if interval <= 0 {
interval = 30
}
ticker := time.NewTicker(time.Duration(interval) * time.Second)
go func() {
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
_, _ = job.Cleaner.RunOnce(ctx)
}
}
}()
}
// queryPending // queryPending
if g.Cfg().MustGet(ctx, "asynch.queryPending.enabled").Bool() { if g.Cfg().MustGet(ctx, "asynch.queryPending.enabled").Bool() {
interval := g.Cfg().MustGet(ctx, "asynch.queryPending.intervalSeconds", 10).Int() interval := g.Cfg().MustGet(ctx, "asynch.queryPending.intervalSeconds", 10).Int()
@@ -79,7 +58,7 @@ func startAutoRunner(ctx context.Context) {
case <-ctx.Done(): case <-ctx.Done():
return return
case <-ticker.C: case <-ticker.C:
if _, err := task.Task.QueryPendingTasks(ctx, &dto.QueryPendingTasksReq{Limit: limit}); err != nil { if _, err := task.ModelGatewayTask.QueryPendingTasks(ctx, &dto.QueryPendingTasksReq{Limit: limit}); err != nil {
g.Log().Warningf(ctx, "[auto-queryPending] run once failed: %v", err) g.Log().Warningf(ctx, "[auto-queryPending] run once failed: %v", err)
} }
} }

View File

@@ -1,190 +0,0 @@
package dto
import (
"model-gateway/model/entity"
"gitea.redpowerfuture.com/red-future/common/beans"
"github.com/gogf/gf/v2/frame/g"
)
// CreateModelReq 添加模型配置
type CreateModelReq struct {
g.Meta `path:"/createModel" method:"post" tags:"模型管理" summary:"创建模型配置" dc:"添加新的模型配置"`
ModelName string `p:"modelName" json:"modelName" v:"required#模型名称不能为空" dc:"模型名称(唯一标识)"`
ModelType int `p:"modelType" json:"modelType" v:"required#模型类型不能为空" dc:"模型类型"`
BaseURL string `p:"baseUrl" json:"baseUrl" v:"required#模型地址不能为空" dc:"模型服务地址"`
HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式GET/POST默认POST"`
HeadMsg map[string]any `p:"headMsg" json:"headMsg" dc:"请求头JSON结构"`
IsPrivate *int `p:"isPrivate" json:"isPrivate" dc:"是否私有化0-私有 1-公共"`
Enabled *int `p:"enabled" json:"enabled" dc:"是否启用0-停用 1-启用"`
IsChatModel *int `p:"isChatModel" json:"isChatModel" dc:"是否为对话模型0-否 1-是"`
CallModel *int `p:"callModel" json:"callModel" dc:"调用模式0-同步 1-异步 2-流式"`
RequiredFields []string `p:"requiredFields" json:"requiredFields" dc:"必填字段"`
IsOwner *int `p:"isOwner" json:"isOwner" dc:"是否为所有者0-否 1-是"`
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥"`
Form []map[string]any `p:"form" json:"form" dc:"动态表单配置"`
RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求映射"`
ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回映射"`
ResponseBody string `p:"responseBody" json:"responseBody" dc:"返回主体"`
ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"`
OperatorName string `p:"operatorName" json:"operatorName" dc:"运营商名称"`
TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"`
ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"`
QueryConfig map[string]any `p:"queryConfig" json:"queryConfig" dc:"查询/回调配置"`
StreamConfig map[string]any `p:"streamConfig" json:"streamConfig" dc:"流式输出配置"`
FirstFrame string `p:"firstFrame" json:"firstFrame" dc:"首帧图片参数"`
LastFrame string `p:"lastFrame" json:"lastFrame" dc:"尾帧图片参数"`
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数默认10"`
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间默认600"`
RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数默认3"`
AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"任务完成后自动清理时间默认86400"`
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址"`
}
type CreateModelRes struct {
ID int64 `json:"id,string" dc:"配置ID"`
}
type UpdateModelReq struct {
g.Meta `path:"/updateModel" method:"put" tags:"模型管理" summary:"更新模型配置" dc:"更新指定ID的模型配置"`
ID int64 `p:"id" json:"id" v:"required#id不能为空" dc:"配置ID"`
ModelName string `p:"modelName" json:"modelName" dc:"模型名称"`
ModelType int `p:"modelType" json:"modelType" dc:"模型类型"`
BaseURL string `p:"baseUrl" json:"baseUrl" dc:"模型服务地址"`
HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式GET/POST"`
HeadMsg map[string]any `p:"headMsg" json:"headMsg" dc:"请求头JSON结构"`
IsPrivate *int `p:"isPrivate" json:"isPrivate" dc:"是否私有化0-私有 1-公共"`
Enabled *int `p:"enabled" json:"enabled" dc:"是否启用0-停用 1-启用"`
IsChatModel *int `p:"isChatModel" json:"isChatModel" dc:"是否为对话模型0-否 1-是"`
CallModel *int `p:"callModel" json:"callModel" dc:"调用模式0-同步 1-异步 2-流式"`
RequiredFields []string `p:"requiredFields" json:"requiredFields" dc:"必填字段"`
IsOwner *int `p:"isOwner" json:"isOwner" dc:"是否为所有者0-否 1-是"`
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥"`
Form []map[string]any `p:"form" json:"form" dc:"动态表单配置"`
RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求映射"`
ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回映射"`
ResponseBody string `p:"responseBody" json:"responseBody" dc:"返回主体"`
ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"`
OperatorName string `p:"operatorName" json:"operatorName" dc:"运营商名称"`
TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"`
ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"`
QueryConfig map[string]any `p:"queryConfig" json:"queryConfig" dc:"查询/回调配置"`
StreamConfig map[string]any `p:"streamConfig" json:"streamConfig" dc:"流式输出配置"`
FirstFrame string `p:"firstFrame" json:"firstFrame" dc:"首帧图片参数"`
LastFrame string `p:"lastFrame" json:"lastFrame" dc:"尾帧图片参数"`
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数"`
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒)"`
RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数"`
AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"任务完成后自动清理时间(秒)"`
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址"`
}
type UpdateModelRes struct {
ID int64 `json:"id,string" dc:"配置ID"`
}
// DeleteModelReq 删除模型配置
type DeleteModelReq struct {
g.Meta `path:"/deleteModel" method:"delete" tags:"模型管理" summary:"删除模型配置" dc:"删除指定ID的模型配置"`
ID int64 `p:"id" json:"id,string" v:"required#id不能为空" dc:"配置ID"`
}
type DeleteModelRes struct {
ID int64 `json:"id,string" dc:"配置ID"`
}
// GetModelReq 获取模型配置详情
type GetModelReq struct {
g.Meta `path:"/getModel" method:"get" tags:"模型管理" summary:"获取模型配置" dc:"根据模型ID获取配置详情"`
ID int64 `p:"id" json:"id,string" dc:"配置ID"`
Creator string `p:"creator" json:"creator" dc:"创建人"`
IsChatModel *int `p:"isChatModel" json:"isChatModel" dc:"是否为聊天模型"`
ModelName string `p:"modelName" json:"modelName" dc:"模型名称(唯一标识)"`
}
type GetModelRes struct {
Model *entity.AsynchModel `json:"model" dc:"模型配置详情"`
}
// ListModelReq 配置列表
type ListModelReq struct {
g.Meta `path:"/listModel" method:"get" tags:"模型管理" summary:"模型配置列表" dc:"分页获取模型配置列表"`
Page *beans.Page `json:"page"`
ModelName string `p:"modelName" json:"modelName" dc:"模型名称(模糊查询,可选)"`
ModelType int `p:"modelType" json:"modelType" dc:"模型类型"`
Enabled *int `p:"enabled" json:"enabled" dc:"是否启用0-禁用1-启用"`
IsPrivate *int `p:"isPrivate" json:"isPrivate" dc:"是否私有化 0-私有 1-公共"`
IsOwner *int `p:"isOwner" json:"isOwner" dc:"是否为所有者 0-否 1-是"`
Creator string `p:"creator" json:"creator" dc:"创建人"`
}
type ListModelRes struct {
List any `json:"list" dc:"列表数据"`
Total int `json:"total" dc:"总数"`
}
// AutoTuneReq 动态调参(由上层定时任务每小时触发一次)
type AutoTuneReq struct {
g.Meta `path:"/autoTune" method:"post" tags:"模型管理" summary:"动态调参" dc:"按 model_name 维度统计指定时间窗口内执行耗时(P90),动态生成运行时 max_concurrency/queue_limit不超过配置上限写入 Redis 供 Worker/CreateTask 使用windowSeconds 不传默认 3600"`
WindowSeconds int `p:"windowSeconds" json:"windowSeconds" dc:"统计窗口秒数;不传/<=0 默认 36001小时"`
}
type AutoTuneRes struct {
List any `json:"list" dc:"调参结果列表"`
}
type ModelTypeModelItem struct {
ID int64 `json:"id" dc:"模型主键ID"`
Name string `json:"name" dc:"模型名称"`
Form any `json:"form" dc:"动态表单配置JSON数组用于前端渲染"`
}
// ListModelTypeReq 模型类型列表(分页)
type ListTypeReq struct {
g.Meta `path:"/listType" method:"get" tags:"模型类型列表" summary:"模型类型列表" dc:"分页获取模型类型列表"`
}
type TypeItem struct {
Type map[int]string `json:"type" dc:"模型类型ID到名称的映射"`
}
type ListOperatorReq struct {
g.Meta `path:"/listOperator" method:"get" tags:"模型管理" summary:"获取运营商列表" dc:"获取运营商列表"`
}
type ListOperatorRes struct {
List []string `json:"list" dc:"运营商名称到ID的映射"`
}
type UpdateChatModelReq struct {
g.Meta `path:"/updateChatModel" method:"post" tags:"模型管理" summary:"更新聊天模型" dc:"更新指定模型的聊天模型"`
Id int64 `p:"id" json:"id" v:"required#model不能为空" dc:"模型id"`
}
type UpdateChatModelRes struct {
ID int64 `json:"id,string" dc:"模型ID"`
}
type GetIsChatModelReq struct {
g.Meta `path:"/getIsChatModel" method:"get" tags:"模型管理" summary:"获取模型是否为聊天模型" dc:"根据模型ID获取是否为聊天模型"`
}
type GetIsChatModelRes struct {
Model any `json:"model" dc:"模型详情"`
}
// NodeFormField 节点表单
type NodeFormField struct {
Value any `json:"value" dc:"字段值"`
Field string `json:"field" dc:"字段标识"`
Label string `json:"label" dc:"字段标签"`
Type string `json:"type" dc:"字段类型"`
Required bool `json:"required" dc:"是否必填"`
Default any `json:"default,omitempty" dc:"默认值"`
Options []SelectOption `json:"options" dc:"下拉选项列表"`
FieldConstraint any `json:"fieldConstraint" dc:"字段约束"`
}
type SelectOption struct {
Label string `json:"label" dc:"选项标签"`
Value string `json:"value" dc:"选项值"`
}

View File

@@ -0,0 +1,20 @@
package dto
import "github.com/gogf/gf/v2/frame/g"
// ListModelStatReq 统计列表
type ListModelStatReq struct {
g.Meta `path:"/listModelStat" method:"get" tags:"统计" summary:"模型请求统计列表" dc:"按天统计模型请求次数,支持分页与条件筛选"`
PageNum int `p:"pageNum" json:"pageNum" dc:"页码默认1"`
PageSize int `p:"pageSize" json:"pageSize" dc:"每页条数默认10"`
StartDay string `p:"startDay" json:"startDay" dc:"开始日期YYYY-MM-DD可选"`
EndDay string `p:"endDay" json:"endDay" dc:"结束日期YYYY-MM-DD可选"`
TenantID *int64 `p:"tenantId" json:"tenantId" dc:"租户ID可选"`
Creator string `p:"creator" json:"creator" dc:"创建人(可选,模糊匹配)"`
ModelName string `p:"modelName" json:"modelName" dc:"模型名称(可选,模糊匹配)"`
}
type ListModelStatRes struct {
List any `json:"list" dc:"列表数据"`
Total int64 `json:"total" dc:"总数"`
}

View File

@@ -0,0 +1,186 @@
package dto
import (
"model-gateway/model/entity"
"gitea.redpowerfuture.com/red-future/common/beans"
"github.com/gogf/gf/v2/frame/g"
)
// CreateModelReq 添加模型配置
type CreateModelReq struct {
g.Meta `path:"/createModel" method:"post" tags:"模型管理" summary:"创建模型配置" dc:"添加新的模型配置"`
ModelName string `p:"modelName" json:"modelName" v:"required#模型名称不能为空" dc:"模型名称(唯一标识)"`
ModelType int `p:"modelType" json:"modelType" v:"required#模型类型不能为空" dc:"模型类型"`
BaseURL string `p:"baseUrl" json:"baseUrl" v:"required#模型地址不能为空" dc:"模型服务地址"`
HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式GET/POST默认POST"`
HeadMsg map[string]any `p:"headMsg" json:"headMsg" dc:"请求头JSON结构"`
IsPrivate *int `p:"isPrivate" json:"isPrivate" dc:"是否私有化0-私有 1-公共"`
Enabled *int `p:"enabled" json:"enabled" dc:"是否启用0-停用 1-启用"`
IsChatModel *int `p:"isChatModel" json:"isChatModel" dc:"是否为对话模型0-否 1-是"`
CallModel *int `p:"callModel" json:"callModel" dc:"调用模式0-同步 1-异步 2-流式"`
RequiredFields []string `p:"requiredFields" json:"requiredFields" dc:"必填字段"`
IsOwner *int `p:"isOwner" json:"isOwner" dc:"是否为所有者0-否 1-是"`
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥"`
Form []map[string]any `p:"form" json:"form" dc:"动态表单配置"`
RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求映射"`
ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回映射"`
OperatorName string `p:"operatorName" json:"operatorName" dc:"运营商名称"`
TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"`
ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"`
QueryConfig map[string]any `p:"queryConfig" json:"queryConfig" dc:"查询/回调配置"`
StreamConfig map[string]any `p:"streamConfig" json:"streamConfig" dc:"流式输出配置"`
FirstFrame string `p:"firstFrame" json:"firstFrame" dc:"首帧图片参数"`
LastFrame string `p:"lastFrame" json:"lastFrame" dc:"尾帧图片参数"`
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数默认10"`
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间默认600"`
RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数默认3"`
AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"任务完成后自动清理时间默认86400"`
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址"`
}
type CreateModelRes struct {
ID int64 `json:"id,string" dc:"配置ID"`
}
type UpdateModelReq struct {
g.Meta `path:"/updateModel" method:"put" tags:"模型管理" summary:"更新模型配置" dc:"更新指定ID的模型配置"`
ID int64 `p:"id" json:"id" v:"required#id不能为空" dc:"配置ID"`
ModelName string `p:"modelName" json:"modelName" dc:"模型名称"`
ModelType int `p:"modelType" json:"modelType" dc:"模型类型"`
BaseURL string `p:"baseUrl" json:"baseUrl" dc:"模型服务地址"`
HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式GET/POST"`
HeadMsg map[string]any `p:"headMsg" json:"headMsg" dc:"请求头JSON结构"`
IsPrivate *int `p:"isPrivate" json:"isPrivate" dc:"是否私有化0-私有 1-公共"`
Enabled *int `p:"enabled" json:"enabled" dc:"是否启用0-停用 1-启用"`
IsChatModel *int `p:"isChatModel" json:"isChatModel" dc:"是否为对话模型0-否 1-是"`
CallModel *int `p:"callModel" json:"callModel" dc:"调用模式0-同步 1-异步 2-流式"`
RequiredFields []string `p:"requiredFields" json:"requiredFields" dc:"必填字段"`
IsOwner *int `p:"isOwner" json:"isOwner" dc:"是否为所有者0-否 1-是"`
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥"`
Form []map[string]any `p:"form" json:"form" dc:"动态表单配置"`
RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求映射"`
ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回映射"`
OperatorName string `p:"operatorName" json:"operatorName" dc:"运营商名称"`
TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"`
ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"`
QueryConfig map[string]any `p:"queryConfig" json:"queryConfig" dc:"查询/回调配置"`
StreamConfig map[string]any `p:"streamConfig" json:"streamConfig" dc:"流式输出配置"`
FirstFrame string `p:"firstFrame" json:"firstFrame" dc:"首帧图片参数"`
LastFrame string `p:"lastFrame" json:"lastFrame" dc:"尾帧图片参数"`
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数"`
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒)"`
RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数"`
AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"任务完成后自动清理时间(秒)"`
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址"`
}
type UpdateModelRes struct {
ID int64 `json:"id,string" dc:"配置ID"`
}
// DeleteModelReq 删除模型配置
type DeleteModelReq struct {
g.Meta `path:"/deleteModel" method:"delete" tags:"模型管理" summary:"删除模型配置" dc:"删除指定ID的模型配置"`
ID int64 `p:"id" json:"id,string" v:"required#id不能为空" dc:"配置ID"`
}
type DeleteModelRes struct {
ID int64 `json:"id,string" dc:"配置ID"`
}
// GetModelReq 获取模型配置详情
type GetModelReq struct {
g.Meta `path:"/getModel" method:"get" tags:"模型管理" summary:"获取模型配置" dc:"根据模型ID获取配置详情"`
ID int64 `p:"id" json:"id,string" dc:"配置ID"`
Creator string `p:"creator" json:"creator" dc:"创建人"`
IsChatModel *int `p:"isChatModel" json:"isChatModel" dc:"是否为聊天模型"`
ModelName string `p:"modelName" json:"modelName" dc:"模型名称(唯一标识)"`
}
type GetModelRes struct {
Model *entity.ModelGatewayModel `json:"model" dc:"模型配置详情"`
}
// ListModelReq 配置列表
type ListModelReq struct {
g.Meta `path:"/listModel" method:"get" tags:"模型管理" summary:"模型配置列表" dc:"分页获取模型配置列表"`
Page *beans.Page `json:"page"`
ModelName string `p:"modelName" json:"modelName" dc:"模型名称(模糊查询,可选)"`
ModelType int `p:"modelType" json:"modelType" dc:"模型类型"`
Enabled *int `p:"enabled" json:"enabled" dc:"是否启用0-禁用1-启用"`
IsPrivate *int `p:"isPrivate" json:"isPrivate" dc:"是否私有化 0-私有 1-公共"`
IsOwner *int `p:"isOwner" json:"isOwner" dc:"是否为所有者 0-否 1-是"`
Creator string `p:"creator" json:"creator" dc:"创建人"`
}
type ListModelRes struct {
List any `json:"list" dc:"列表数据"`
Total int `json:"total" dc:"总数"`
}
// AutoTuneReq 动态调参(由上层定时任务每小时触发一次)
type AutoTuneReq struct {
g.Meta `path:"/autoTune" method:"post" tags:"模型管理" summary:"动态调参" dc:"按 model_name 维度统计指定时间窗口内执行耗时(P90),动态生成运行时 max_concurrency/queue_limit不超过配置上限写入 Redis 供 Worker/CreateTask 使用windowSeconds 不传默认 3600"`
WindowSeconds int `p:"windowSeconds" json:"windowSeconds" dc:"统计窗口秒数;不传/<=0 默认 36001小时"`
}
type AutoTuneRes struct {
List any `json:"list" dc:"调参结果列表"`
}
type ModelTypeModelItem struct {
ID int64 `json:"id" dc:"模型主键ID"`
Name string `json:"name" dc:"模型名称"`
Form any `json:"form" dc:"动态表单配置JSON数组用于前端渲染"`
}
// ListModelTypeReq 模型类型列表(分页)
type ListTypeReq struct {
g.Meta `path:"/listType" method:"get" tags:"模型类型列表" summary:"模型类型列表" dc:"分页获取模型类型列表"`
}
type TypeItem struct {
Type map[int]string `json:"type" dc:"模型类型ID到名称的映射"`
}
type ListOperatorReq struct {
g.Meta `path:"/listOperator" method:"get" tags:"模型管理" summary:"获取运营商列表" dc:"获取运营商列表"`
}
type ListOperatorRes struct {
List []string `json:"list" dc:"运营商名称到ID的映射"`
}
type UpdateChatModelReq struct {
g.Meta `path:"/updateChatModel" method:"post" tags:"模型管理" summary:"更新聊天模型" dc:"更新指定模型的聊天模型"`
Id int64 `p:"id" json:"id" v:"required#model不能为空" dc:"模型id"`
}
type UpdateChatModelRes struct {
ID int64 `json:"id,string" dc:"模型ID"`
}
type GetIsChatModelReq struct {
g.Meta `path:"/getIsChatModel" method:"get" tags:"模型管理" summary:"获取模型是否为聊天模型" dc:"根据模型ID获取是否为聊天模型"`
}
type GetIsChatModelRes struct {
Model any `json:"model" dc:"模型详情"`
}
// NodeFormField 节点表单
type NodeFormField struct {
Value any `json:"value" dc:"字段值"`
Field string `json:"field" dc:"字段标识"`
Label string `json:"label" dc:"字段标签"`
Type string `json:"type" dc:"字段类型"`
Required bool `json:"required" dc:"是否必填"`
Default any `json:"default,omitempty" dc:"默认值"`
Options []SelectOption `json:"options" dc:"下拉选项列表"`
FieldConstraint any `json:"fieldConstraint" dc:"字段约束"`
}
type SelectOption struct {
Label string `json:"label" dc:"选项标签"`
Value string `json:"value" dc:"选项值"`
}

View File

@@ -1,6 +1,8 @@
package dto package dto
import "github.com/gogf/gf/v2/frame/g" import (
"github.com/gogf/gf/v2/frame/g"
)
// CreateTaskReq 创建异步任务 // CreateTaskReq 创建异步任务
type CreateTaskReq struct { type CreateTaskReq struct {
@@ -8,7 +10,6 @@ type CreateTaskReq struct {
ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"模型名称"` ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"模型名称"`
BizName string `p:"bizName" json:"bizName" dc:"业务名称(调用方模块/系统,用于统计)"` BizName string `p:"bizName" json:"bizName" dc:"业务名称(调用方模块/系统,用于统计)"`
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址(可选,用于后续业务通知)"` CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址(可选,用于后续业务通知)"`
InputRef string `p:"inputRef" json:"inputRef" dc:"输入引用如OSS/文件引用等)"`
RequestPayload map[string]any `p:"requestPayload" json:"requestPayload" dc:"请求负载(透传给模型服务)"` RequestPayload map[string]any `p:"requestPayload" json:"requestPayload" dc:"请求负载(透传给模型服务)"`
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"` EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
BuildType int64 `json:"buildType" dc:"构建类型1-提示词构建 2-节点构建"` BuildType int64 `json:"buildType" dc:"构建类型1-提示词构建 2-节点构建"`
@@ -67,14 +68,15 @@ type GetTaskBatchReq struct {
TaskIDs []string `p:"taskIds" json:"taskIds" v:"required#taskIds不能为空" dc:"任务ID列表"` TaskIDs []string `p:"taskIds" json:"taskIds" v:"required#taskIds不能为空" dc:"任务ID列表"`
} }
type GetTaskBatchRes struct {
List []GetTaskBatchItem `json:"list" dc:"任务列表"`
}
type GetTaskBatchItem struct { type GetTaskBatchItem struct {
TaskID string `json:"taskId" dc:"任务ID"` TaskID string `json:"taskId" dc:"任务ID"`
State int `json:"state" dc:"任务状态"` State int `json:"state" dc:"任务状态"`
OssFile string `json:"ossFile" dc:"结果文件OSS地址"` OssFile string `json:"ossFile" dc:"结果文件OSS地址"`
} TextResult map[string]any `json:"textResult" dc:"文本结果"`
type GetTaskBatchRes struct {
List []GetTaskBatchItem `json:"list" dc:"任务列表"`
} }
// ListTaskReq 任务列表分页查询 // ListTaskReq 任务列表分页查询
@@ -83,8 +85,9 @@ type ListTaskReq struct {
PageNum int `p:"pageNum" json:"pageNum" dc:"页码默认1"` PageNum int `p:"pageNum" json:"pageNum" dc:"页码默认1"`
PageSize int `p:"pageSize" json:"pageSize" dc:"每页条数默认10"` PageSize int `p:"pageSize" json:"pageSize" dc:"每页条数默认10"`
ModelName string `p:"modelName" json:"modelName" dc:"模型名称(模糊匹配)"` ModelName string `p:"modelName" json:"modelName" dc:"模型名称(模糊匹配)"`
BizName string `p:"bizName" json:"bizName" dc:"业务名称"`
TaskID string `p:"taskId" json:"taskId" dc:"任务ID模糊匹配"` TaskID string `p:"taskId" json:"taskId" dc:"任务ID模糊匹配"`
State *int `p:"state" json:"state" dc:"任务状态0/1/2/3/4可选"` State int `p:"state" json:"state" dc:"任务状态0/1/2/3/4可选"`
} }
type ListTaskRes struct { type ListTaskRes struct {
@@ -102,12 +105,3 @@ type RunWorkReq struct {
type RunWorkRes struct { type RunWorkRes struct {
Claimed int `json:"claimed" dc:"本次抢占并处理的任务数"` Claimed int `json:"claimed" dc:"本次抢占并处理的任务数"`
} }
// CleanWorkReq 手动触发 cleaner 执行一次(由上层定时任务调用)
type CleanWorkReq struct {
g.Meta `path:"/cleanWork" method:"post" tags:"任务管理" summary:"执行一次Cleaner" dc:"手动触发一次清理/重试(用于由上层定时任务控制)"`
}
type CleanWorkRes struct {
Ok bool `json:"ok" dc:"是否执行成功"`
}

View File

@@ -1,20 +0,0 @@
package dto
import "github.com/gogf/gf/v2/frame/g"
// ListModelStatReq 统计列表
type ListModelStatReq struct {
g.Meta `path:"/listModelStat" method:"get" tags:"统计" summary:"模型请求统计列表" dc:"按天统计模型请求次数,支持分页与条件筛选"`
PageNum int `p:"pageNum" json:"pageNum" dc:"页码默认1"`
PageSize int `p:"pageSize" json:"pageSize" dc:"每页条数默认10"`
StartDay string `p:"startDay" json:"startDay" dc:"开始日期YYYY-MM-DD可选"`
EndDay string `p:"endDay" json:"endDay" dc:"结束日期YYYY-MM-DD可选"`
TenantID *int64 `p:"tenantId" json:"tenantId" dc:"租户ID可选"`
Creator string `p:"creator" json:"creator" dc:"创建人(可选,模糊匹配)"`
ModelName string `p:"modelName" json:"modelName" dc:"模型名称(可选,模糊匹配)"`
}
type ListModelStatRes struct {
List any `json:"list" dc:"列表数据"`
Total int64 `json:"total" dc:"总数"`
}

View File

@@ -1,103 +0,0 @@
package entity
import "gitea.redpowerfuture.com/red-future/common/beans"
type asynchModelCol struct {
beans.SQLBaseCol
ModelName string
ModelType string
BaseURL string
HttpMethod string
HeadMsg string
FormJSON string
RequestMapping string
ResponseMapping string
ResponseBody string
ResponseTokenField string
RequiredFields string
IsPrivate string
IsChatModel string
CallMode string
ApiKey string
Enabled string
MaxConcurrency string
TimeoutSeconds string
RetryTimes string
AutoCleanSeconds string
IsOwner string
OperatorName string
TokenConfig string
ExtendMapping string
QueryConfig string
StreamConfig string
FirstFrame string
LastFrame string
CallbackUrl string
}
var AsynchModelCol = asynchModelCol{
SQLBaseCol: beans.DefSQLBaseCol,
ModelName: "model_name",
ModelType: "model_type",
BaseURL: "base_url",
HttpMethod: "http_method",
HeadMsg: "head_msg",
FormJSON: "form_json",
RequestMapping: "request_mapping",
ResponseMapping: "response_mapping",
ResponseBody: "response_body",
ResponseTokenField: "response_token_field",
RequiredFields: "required_fields",
IsPrivate: "is_private",
IsChatModel: "is_chat_model",
CallMode: "call_mode",
ApiKey: "api_key",
Enabled: "enabled",
MaxConcurrency: "max_concurrency",
TimeoutSeconds: "timeout_seconds",
RetryTimes: "retry_times",
AutoCleanSeconds: "auto_clean_seconds",
IsOwner: "is_owner",
OperatorName: "operator_name",
TokenConfig: "token_config",
ExtendMapping: "extend_mapping",
QueryConfig: "query_config",
StreamConfig: "stream_config",
FirstFrame: "first_frame",
LastFrame: "last_frame",
CallbackUrl: "callback_url",
}
// AsynchModel 异步模型配置
type AsynchModel struct {
beans.SQLBaseDO `orm:",inline"`
ModelName string `orm:"model_name" json:"modelName"`
ModelType int `orm:"model_type" json:"modelType"`
BaseURL string `orm:"base_url" json:"baseUrl"`
HttpMethod string `orm:"http_method" json:"httpMethod"`
HeadMsg map[string]any `orm:"head_msg" json:"headMsg"`
Form []map[string]any `orm:"form_json" json:"form"`
RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"`
ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"`
ResponseBody string `orm:"response_body" json:"responseBody"`
ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"`
RequiredFields []string `orm:"required_fields" json:"requiredFields"`
IsPrivate *int `orm:"is_private" json:"isPrivate"`
IsChatModel *int `orm:"is_chat_model" json:"isChatModel"`
CallMode *int `orm:"call_mode" json:"callMode"`
ApiKey string `orm:"api_key" json:"apiKey"`
Enabled *int `orm:"enabled" json:"enabled"`
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
RetryTimes int `orm:"retry_times" json:"retryTimes"`
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
IsOwner *int `json:"isOwner" orm:"is_owner"`
OperatorName string `orm:"operator_name" json:"operatorName"`
TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"`
ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"`
QueryConfig map[string]any `orm:"query_config" json:"queryConfig"`
StreamConfig map[string]any `orm:"stream_config" json:"streamConfig"`
FirstFrame string `orm:"first_frame" json:"firstFrame"`
LastFrame string `orm:"last_frame" json:"lastFrame"`
CallbackUrl string `orm:"callback_url" json:"callbackUrl"`
}

View File

@@ -1,89 +0,0 @@
package entity
import (
"gitea.redpowerfuture.com/red-future/common/beans"
"github.com/gogf/gf/v2/os/gtime"
)
type asynchTaskCol struct {
beans.SQLBaseCol
ModelName string
TaskID string
BizName string
CallbackURL string
ModelKey string
State string
OssFile string
FileType string
FileSize string
ErrorMsg string
StartedAt string
FinishedAt string
DurationSeconds string
ExpireAt string
RetryCount string
EnqueueAt string
Phase string
TmpFile string
InputRef string
RequestPayload string
TextResult string
EpicycleId string
ExpendTokens string
}
var AsynchTaskCol = asynchTaskCol{
SQLBaseCol: beans.DefSQLBaseCol,
ModelName: "model_name",
TaskID: "task_id",
BizName: "biz_name",
CallbackURL: "callback_url",
ModelKey: "model_key",
State: "state",
OssFile: "oss_file",
FileType: "file_type",
FileSize: "file_size",
ErrorMsg: "error_msg",
StartedAt: "started_at",
FinishedAt: "finished_at",
DurationSeconds: "duration_seconds",
ExpireAt: "expire_at",
RetryCount: "retry_count",
EnqueueAt: "enqueue_at",
Phase: "phase",
TmpFile: "tmp_file",
InputRef: "input_ref",
RequestPayload: "request_payload",
TextResult: "text_result",
EpicycleId: "epicycle_id",
ExpendTokens: "expend_tokens",
}
// AsynchTask 异步任务
type AsynchTask struct {
beans.SQLBaseDO `orm:",inline"`
ModelName string `orm:"model_name" json:"modelName"`
TaskID string `orm:"task_id" json:"taskId"`
BizName string `orm:"biz_name" json:"bizName"`
CallbackURL string `orm:"callback_url" json:"callbackUrl"`
ModelKey string `orm:"model_key" json:"modelKey"`
State int `orm:"state" json:"state"` // 0排队中/1执行中/2成功/3失败/4已下载
OssFile string `orm:"oss_file" json:"ossFile"`
FileType string `orm:"file_type" json:"fileType"`
FileSize int64 `orm:"file_size" json:"fileSize"`
ErrorMsg string `orm:"error_msg" json:"errorMsg"`
StartedAt *gtime.Time `orm:"started_at" json:"startedAt"`
FinishedAt *gtime.Time `orm:"finished_at" json:"finishedAt"`
DurationSeconds int64 `orm:"duration_seconds" json:"durationSeconds"`
ExpireAt *gtime.Time `orm:"expire_at" json:"expireAt"` // 已下载(state=4)后的过期时间
RetryCount int `orm:"retry_count" json:"retryCount"`
EnqueueAt *gtime.Time `orm:"enqueue_at" json:"enqueueAt"`
Phase int `orm:"phase" json:"phase"` // 0模型阶段/1OSS阶段
TmpFile string `orm:"tmp_file" json:"tmpFile"` // 临时结果文件路径
InputRef string `orm:"input_ref" json:"inputRef"`
RequestPayload map[string]any `orm:"request_payload" json:"requestPayload"`
TextResult map[string]any `orm:"text_result" json:"text"`
EpicycleId int64 `orm:"epicycle_id" json:"epicycleId"` // 轮次ID用于标识同一轮次的任务
ExpendTokens int64 `orm:"expend_tokens" json:"expendTokens"` // 消耗 token 数
RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"-"`
}

View File

@@ -1,57 +0,0 @@
package entity
import (
"gitea.redpowerfuture.com/red-future/common/beans"
)
type LogsModelPpCol struct {
beans.SQLBaseCol
IP string
UserAgent string
APIPath string
HttpMethod string
BizName string
ModelName string
TaskID string
OpType string
Success string
ErrorMsg string
CostMs string
RequestPayload string
ResponsePayload string
}
var LogsModelOpCol = LogsModelPpCol{
SQLBaseCol: beans.DefSQLBaseCol,
IP: "ip",
UserAgent: "user_agent",
APIPath: "api_path",
HttpMethod: "http_method",
BizName: "biz_name",
ModelName: "model_name",
TaskID: "task_id",
OpType: "op_type",
Success: "success",
ErrorMsg: "error_msg",
CostMs: "cost_ms",
RequestPayload: "request_payload",
ResponsePayload: "response_payload",
}
// LogsModelOp 操作日志(创建任务等)
type LogsModelOp struct {
beans.SQLBaseDO `orm:",inline"`
IP string `orm:"ip" json:"ip"`
UserAgent string `orm:"user_agent" json:"userAgent"`
APIPath string `orm:"api_path" json:"apiPath"`
HttpMethod string `orm:"http_method" json:"httpMethod"`
BizName string `orm:"biz_name" json:"bizName"`
ModelName string `orm:"model_name" json:"modelName"`
TaskID string `orm:"task_id" json:"taskId"`
OpType string `orm:"op_type" json:"opType"`
Success int `orm:"success" json:"success"`
ErrorMsg string `orm:"error_msg" json:"errorMsg"`
CostMs int64 `orm:"cost_ms" json:"costMs"`
RequestPayload any `orm:"request_payload" json:"requestPayload"`
ResponsePayload any `orm:"response_payload" json:"responsePayload"`
}

View File

@@ -1,38 +0,0 @@
package entity
import (
"github.com/gogf/gf/v2/os/gtime"
)
// LogsModelStatCol 字段常量
type LogsModelStatCol struct {
Day string
TenantId string
Creator string
ModelName string
RequestCount string
CreatedAt string
UpdatedAt string
}
var LogsModelStatCols = LogsModelStatCol{
Day: "day",
TenantId: "tenant_id",
Creator: "creator",
ModelName: "model_name",
RequestCount: "request_count",
CreatedAt: "created_at",
UpdatedAt: "updated_at",
}
// LogsModelStat 按天统计:某天/租户/创建人/模型的请求次数
// 注:这里不走通用 SQLBaseDO采用联合唯一键day,tenant_id,creator,model_name做 UPSERT 原子累加。
type LogsModelStat struct {
Day *gtime.Time `orm:"day" json:"day"` // 日期(建议仅使用日期部分)
TenantId int64 `orm:"tenant_id" json:"tenantId"` // 租户ID
Creator string `orm:"creator" json:"creator"` // 创建人/操作人
ModelName string `orm:"model_name" json:"modelName"` // 模型名称
RequestCount int64 `orm:"request_count" json:"requestCount"` // 请求次数
CreatedAt *gtime.Time `orm:"created_at" json:"createdAt"` // 创建时间
UpdatedAt *gtime.Time `orm:"updated_at" json:"updatedAt"` // 更新时间
}

View File

@@ -0,0 +1,56 @@
package entity
import "gitea.redpowerfuture.com/red-future/common/beans"
// ModelGatewayLogsOpCol 字段常量
type modelGatewayLogsOpCol struct {
beans.SQLBaseCol
IP string
UserAgent string
APIPath string
HttpMethod string
BizName string
ModelName string
TaskID string
OpType string
Success string
ErrorMsg string
CostMs string
RequestPayload string
ResponsePayload string
}
var ModelGatewayLogsOpCol = modelGatewayLogsOpCol{
SQLBaseCol: beans.DefSQLBaseCol,
IP: "ip",
UserAgent: "user_agent",
APIPath: "api_path",
HttpMethod: "http_method",
BizName: "biz_name",
ModelName: "model_name",
TaskID: "task_id",
OpType: "op_type",
Success: "success",
ErrorMsg: "error_msg",
CostMs: "cost_ms",
RequestPayload: "request_payload",
ResponsePayload: "response_payload",
}
// ModelGatewayLogsOp 操作日志
type ModelGatewayLogsOp struct {
beans.SQLBaseDO `orm:",inline"`
IP string `orm:"ip" json:"ip"`
UserAgent string `orm:"user_agent" json:"userAgent"`
APIPath string `orm:"api_path" json:"apiPath"`
HttpMethod string `orm:"http_method" json:"httpMethod"`
BizName string `orm:"biz_name" json:"bizName"`
ModelName string `orm:"model_name" json:"modelName"`
TaskID string `orm:"task_id" json:"taskId"`
OpType string `orm:"op_type" json:"opType"`
Success int `orm:"success" json:"success"`
ErrorMsg string `orm:"error_msg" json:"errorMsg"`
CostMs int64 `orm:"cost_ms" json:"costMs"`
RequestPayload *RequestPayload `orm:"request_payload" json:"requestPayload"`
ResponsePayload map[string]any `orm:"response_payload" json:"responsePayload"`
}

View File

@@ -0,0 +1,35 @@
package entity
import "github.com/gogf/gf/v2/os/gtime"
// ModelGatewayLogsStatCol 字段常量
type ModelGatewayLogsStatCol struct {
Day string
TenantId string
Creator string
ModelName string
RequestCount string
CreatedAt string
UpdatedAt string
}
var ModelGatewayLogsStatCols = ModelGatewayLogsStatCol{
Day: "day",
TenantId: "tenant_id",
Creator: "creator",
ModelName: "model_name",
RequestCount: "request_count",
CreatedAt: "created_at",
UpdatedAt: "updated_at",
}
// ModelGatewayLogsStat 按天统计
type ModelGatewayLogsStat struct {
Day *gtime.Time `orm:"day" json:"day"`
TenantId uint64 `orm:"tenant_id" json:"tenantId"`
Creator string `orm:"creator" json:"creator"`
ModelName string `orm:"model_name" json:"modelName"`
RequestCount int64 `orm:"request_count" json:"requestCount"`
CreatedAt *gtime.Time `orm:"created_at" json:"createdAt"`
UpdatedAt *gtime.Time `orm:"updated_at" json:"updatedAt"`
}

View File

@@ -0,0 +1,99 @@
package entity
import "gitea.redpowerfuture.com/red-future/common/beans"
type modelGatewayModelCol struct {
beans.SQLBaseCol
ModelName string
ModelType string
BaseURL string
HttpMethod string
HeadMsg string
FormJSON string
RequestMapping string
ResponseMapping string
ResponseBody string
RequiredFields string
IsPrivate string
IsChatModel string
CallMode string
ApiKey string
Enabled string
MaxConcurrency string
TimeoutSeconds string
RetryTimes string
AutoCleanSeconds string
IsOwner string
OperatorName string
TokenConfig string
ExtendMapping string
QueryConfig string
StreamConfig string
FirstFrame string
LastFrame string
}
var ModelGatewayModelCol = modelGatewayModelCol{
SQLBaseCol: beans.DefSQLBaseCol,
ModelName: "model_name",
ModelType: "model_type",
BaseURL: "base_url",
HttpMethod: "http_method",
HeadMsg: "head_msg",
FormJSON: "form_json",
RequestMapping: "request_mapping",
ResponseMapping: "response_mapping",
RequiredFields: "required_fields",
IsPrivate: "is_private",
IsChatModel: "is_chat_model",
CallMode: "call_mode",
ApiKey: "api_key",
Enabled: "enabled",
MaxConcurrency: "max_concurrency",
TimeoutSeconds: "timeout_seconds",
RetryTimes: "retry_times",
AutoCleanSeconds: "auto_clean_seconds",
IsOwner: "is_owner",
OperatorName: "operator_name",
TokenConfig: "token_config",
ExtendMapping: "extend_mapping",
QueryConfig: "query_config",
StreamConfig: "stream_config",
FirstFrame: "first_frame",
LastFrame: "last_frame",
}
type ModelGatewayModel struct {
beans.SQLBaseDO `orm:",inline"`
ModelName string `orm:"model_name" json:"modelName"`
ModelType int `orm:"model_type" json:"modelType"`
BaseURL string `orm:"base_url" json:"baseUrl"`
HttpMethod string `orm:"http_method" json:"httpMethod"`
HeadMsg map[string]any `orm:"head_msg" json:"headMsg"`
Form []map[string]any `orm:"form_json" json:"form"`
RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"`
ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"`
RequiredFields []string `orm:"required_fields" json:"requiredFields"`
IsPrivate *int `orm:"is_private" json:"isPrivate"`
IsChatModel *int `orm:"is_chat_model" json:"isChatModel"`
CallMode *int `orm:"call_mode" json:"callMode"`
ApiKey string `orm:"api_key" json:"apiKey"`
Enabled *int `orm:"enabled" json:"enabled"`
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
RetryTimes int `orm:"retry_times" json:"retryTimes"`
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
IsOwner *int `orm:"is_owner" json:"isOwner"`
OperatorName string `orm:"operator_name" json:"operatorName"`
TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"`
ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"`
QueryConfig map[string]any `orm:"query_config" json:"queryConfig"`
StreamConfig map[string]any `orm:"stream_config" json:"streamConfig"`
FirstFrame string `orm:"first_frame" json:"firstFrame"`
LastFrame string `orm:"last_frame" json:"lastFrame"`
}
const ( //ResponseMapping 下的字段
ResponseBody = "response_body" //返回主体
TotalTokens = "total_tokens" //总token数
)

View File

@@ -0,0 +1,76 @@
package entity
import (
"gitea.redpowerfuture.com/red-future/common/beans"
)
type modelGatewayTaskCol struct {
beans.SQLBaseCol
ModelName string
TaskID string
BizName string
CallbackURL string
State string
Phase string
ErrorMsg string
ResultFile string
TextResult string
ExpendTokens string
DurationSeconds string
RetryCount string
TmpFile string
RequestPayload string
EpicycleId string
}
var ModelGatewayTaskCol = modelGatewayTaskCol{
SQLBaseCol: beans.DefSQLBaseCol,
ModelName: "model_name",
TaskID: "task_id",
BizName: "biz_name",
CallbackURL: "callback_url",
State: "state",
Phase: "phase",
ErrorMsg: "error_msg",
ResultFile: "result_file",
TextResult: "text_result",
ExpendTokens: "expend_tokens",
DurationSeconds: "duration_seconds",
RetryCount: "retry_count",
TmpFile: "tmp_file",
RequestPayload: "request_payload",
EpicycleId: "epicycle_id",
}
// ModelGatewayTask 模型网关任务
type ModelGatewayTask struct {
beans.SQLBaseDO `orm:",inline"`
ModelName string `orm:"model_name" json:"modelName"`
TaskID string `orm:"task_id" json:"taskId"`
BizName string `orm:"biz_name" json:"bizName"`
CallbackURL string `orm:"callback_url" json:"callbackUrl"`
State int `orm:"state" json:"state"`
Phase int `orm:"phase" json:"phase"`
ErrorMsg string `orm:"error_msg" json:"errorMsg"`
ResultFile *ResultFile `orm:"result_file" json:"resultFile"`
TextResult map[string]any `orm:"text_result" json:"text"`
ExpendTokens int64 `orm:"expend_tokens" json:"expendTokens"`
DurationSeconds int64 `orm:"duration_seconds" json:"durationSeconds"`
RetryCount int `orm:"retry_count" json:"retryCount"`
TmpFile string `orm:"tmp_file" json:"tmpFile"`
RequestPayload *RequestPayload `orm:"request_payload" json:"requestPayload"`
EpicycleId int64 `orm:"epicycle_id" json:"epicycleId"`
}
// ResultFile OSS 结果文件
type ResultFile struct {
OssFile string `json:"ossFile"`
FileType string `json:"fileType"`
FileSize int64 `json:"fileSize"`
}
// RequestPayload 请求参数结构体
type RequestPayload struct {
Headers map[string]string `json:"headers"`
Body map[string]any `json:"body"`
}

View File

@@ -24,6 +24,7 @@ type UploadFileResponse struct {
FileAddressPrefix string `json:"fileAddressPrefix"` // 文件地址前缀 FileAddressPrefix string `json:"fileAddressPrefix"` // 文件地址前缀
} }
// UploadByTask 通过任务上传文件
func UploadByTask(ctx context.Context, data []byte, fileExt string) (oss *UploadFileResponse, err error) { func UploadByTask(ctx context.Context, data []byte, fileExt string) (oss *UploadFileResponse, err error) {
// multipart // multipart
body := &bytes.Buffer{} body := &bytes.Buffer{}
@@ -72,20 +73,18 @@ type CallbackPayload struct {
State int `json:"state"` State int `json:"state"`
OssFile string `json:"oss_file"` OssFile string `json:"oss_file"`
FileType string `json:"file_type"` FileType string `json:"file_type"`
Messages map[string]any `json:"messages"`
ErrorMsg string `json:"error_msg"` ErrorMsg string `json:"error_msg"`
} }
// TriggerCallback 任务的回调 // TriggerCallback 任务的回调
func TriggerCallback(ctx context.Context, t *entity.AsynchTask) { func TriggerCallback(ctx context.Context, t *entity.ModelGatewayTask) {
headers := util.ForwardHeaders(ctx) headers := util.ForwardHeaders(ctx)
var resp struct{} var resp struct{}
payload := CallbackPayload{ payload := CallbackPayload{
TaskId: t.TaskID, TaskId: t.TaskID,
State: t.State, State: t.State,
OssFile: t.OssFile, OssFile: t.ResultFile.OssFile,
FileType: t.FileType, FileType: t.ResultFile.FileType,
Messages: t.TextResult,
ErrorMsg: t.ErrorMsg, ErrorMsg: t.ErrorMsg,
} }
jsonData, err := json.Marshal(payload) jsonData, err := json.Marshal(payload)
@@ -111,7 +110,7 @@ type PromptsCallbackPayload struct {
} }
// TriggerPromptsCallback 任务成功后的提示词回调 // TriggerPromptsCallback 任务成功后的提示词回调
func TriggerPromptsCallback(ctx context.Context, t *entity.AsynchTask, epicycleId int64) { func TriggerPromptsCallback(ctx context.Context, t *entity.ModelGatewayTask, epicycleId int64) {
callbackURL := "prompts-core/session/callback" callbackURL := "prompts-core/session/callback"
headers := util.ForwardHeaders(ctx) headers := util.ForwardHeaders(ctx)
var resp struct{} var resp struct{}

View File

@@ -1,99 +0,0 @@
package job
import (
"context"
"model-gateway/model/dto"
"model-gateway/service/queue"
"os"
"time"
"model-gateway/dao"
"github.com/gogf/gf/v2/frame/g"
)
var Cleaner = &cleaner{}
type cleaner struct{}
// RunOnce 由上层定时任务触发:执行一次清理/重试
func (c *cleaner) RunOnce(ctx context.Context) (res *dto.CleanWorkRes, err error) {
// 1) 清理已下载(state=4)且过期的任务(硬删除 + OSS
expired, err := dao.Task.ListExpiredDownloadedGlobal(ctx, 200)
if err != nil {
g.Log().Errorf(ctx, "[清理] 查询已下载过期任务失败: %v", err)
} else {
for _, t := range expired {
_ = os.Remove(t.TmpFile)
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
}
g.Log().Infof(ctx, "[清理] 已下载过期任务清理完成, count=%d", len(expired))
}
// 2) 超时任务标失败
list, err := dao.Task.ListTimeoutTasksGlobal(ctx, 200)
if err != nil {
g.Log().Errorf(ctx, "[清理] 查询超时任务失败: %v", err)
} else {
for _, t := range list {
t.ErrorMsg = "任务超时自动失败"
_ = dao.Task.UpdateFailedGlobal(ctx, t)
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
}
g.Log().Infof(ctx, "[清理] 超时任务处理完成, count=%d", len(list))
}
// 3) 失败(state=3)的任务按模型配置 retry_times 重新入队(放到队尾)
retryable, err := dao.Task.ListFailedRetryableGlobal(ctx, 200)
if err != nil {
g.Log().Errorf(ctx, "[清理] 查询可重试任务失败: %v", err)
} else {
for _, t := range retryable {
// 失败任务重新入队state=3 -> 0先严格占用 queue_limit slot占用失败则留在失败态下一轮再尝试
// 获取模型配置以得到 queue_limit / expected_seconds
m, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
if err != nil || m == nil {
continue
}
limit := queue.GetRuntimeQueueLimit(ctx, t.ModelName, m.MaxConcurrency*2)
if limit > 0 {
ok, _ := queue.AcquireQueueSlot(ctx, t.ModelName, t.TaskID, limit, m.TimeoutSeconds)
if !ok {
continue
}
}
// retry_queue_max_seconds 控制失败重试的排队策略:
// - =0失败重试插队到队首
// - >0当任务从创建到现在的排队时长 >= maxSeconds则插队到队首否则仍放到队尾
now := time.Now()
enqueueAt := now
maxSeconds := t.RetryQueueMaxSeconds
if maxSeconds == 0 {
enqueueAt = now.Add(-100 * 365 * 24 * time.Hour)
} else if maxSeconds > 0 && t.CreatedAt != nil {
if now.Sub(t.CreatedAt.Time) >= time.Duration(maxSeconds)*time.Second {
enqueueAt = now.Add(-100 * 365 * 24 * time.Hour)
}
}
_ = dao.Task.RequeueForRetryGlobal(ctx, t.Id, enqueueAt)
}
g.Log().Infof(ctx, "[清理] 可重试任务重新入队完成, count=%d", len(retryable))
}
// 4) 超过重试次数仍失败(state=3)的任务:硬删除
exhausted, err := dao.Task.ListFailedExhaustedGlobal(ctx, 200)
if err != nil {
g.Log().Errorf(ctx, "[清理] 查询重试耗尽任务失败: %v", err)
} else {
for _, t := range exhausted {
_ = os.Remove(t.TmpFile)
// 重试耗尽硬删除:释放闸门占位(兜底,若此前已释放则幂等)
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
}
g.Log().Infof(ctx, "[清理] 重试耗尽任务清理完成, count=%d", len(exhausted))
}
return &dto.CleanWorkRes{
Ok: true,
}, nil
}

View File

@@ -18,7 +18,7 @@ import (
"github.com/gogf/gf/v2/util/gconv" "github.com/gogf/gf/v2/util/gconv"
) )
var Model = &modelService{} var ModelGatewayModels = &modelService{}
type modelService struct{} type modelService struct{}
@@ -37,7 +37,7 @@ func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (*dt
} }
// 3入库 // 3入库
id, err := dao.Model.Insert(ctx, util.ConvertTo[entity.AsynchModel](req)) id, err := dao.ModelGatewayModels.Insert(ctx, util.ConvertTo[entity.ModelGatewayModel](req))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -56,27 +56,27 @@ func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) erro
req.IsOwner = gconv.PtrInt(1) req.IsOwner = gconv.PtrInt(1)
if isAdmin, _ := gateway.IsSuperAdmin(ctx); isAdmin { if isAdmin, _ := gateway.IsSuperAdmin(ctx); isAdmin {
req.IsOwner = gconv.PtrInt(0) req.IsOwner = gconv.PtrInt(0)
_, err := dao.Model.Update(ctx, util.ConvertTo[entity.AsynchModel](req)) _, err := dao.ModelGatewayModels.Update(ctx, util.ConvertTo[entity.ModelGatewayModel](req))
return err return err
} }
// 3跨租户判断超管的模型不允许直接修改走插入新记录 // 3跨租户判断超管的模型不允许直接修改走插入新记录
model, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{ model, err := dao.ModelGatewayModels.GetByAcrossTenant(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.ID}, SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
}) })
if err != nil { if err != nil {
return err return err
} }
if model.TenantId == 1 { if model.TenantId == 1 {
_, err = dao.Model.Insert(ctx, util.ConvertTo[entity.AsynchModel](req)) _, err = dao.ModelGatewayModels.Insert(ctx, util.ConvertTo[entity.ModelGatewayModel](req))
return err return err
} }
_, err = dao.Model.Update(ctx, util.ConvertTo[entity.AsynchModel](req)) _, err = dao.ModelGatewayModels.Update(ctx, util.ConvertTo[entity.ModelGatewayModel](req))
return err return err
} }
// Delete 删除模型 // Delete 删除模型
func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) error { func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) error {
_, err := dao.Model.Delete(ctx, &entity.AsynchModel{ _, err := dao.ModelGatewayModels.Delete(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.ID}, SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
}) })
return err return err
@@ -91,7 +91,7 @@ func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetM
if g.IsEmpty(req.ID) { if g.IsEmpty(req.ID) {
req.Creator = user.UserName req.Creator = user.UserName
} }
model, err := dao.Model.Get(ctx, &entity.AsynchModel{ model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{ SQLBaseDO: beans.SQLBaseDO{
Id: req.ID, Id: req.ID,
Creator: user.UserName, Creator: user.UserName,
@@ -123,7 +123,7 @@ func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (*dto.Li
req.Creator = user.UserName req.Creator = user.UserName
// 3查询 // 3查询
models, total, err := dao.Model.GetByCreatorAndPlatform(ctx, req) models, total, err := dao.ModelGatewayModels.GetByCreatorAndPlatform(ctx, req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -134,7 +134,7 @@ func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (*dto.Li
// UpdateChatModel 设置会话模型 // UpdateChatModel 设置会话模型
func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error { func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error {
// 1校验新模型存在 // 1校验新模型存在
newModel, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{ newModel, err := dao.ModelGatewayModels.GetByAcrossTenant(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.Id}, SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
}) })
if err != nil || newModel == nil { if err != nil || newModel == nil {
@@ -146,7 +146,7 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM
if err != nil { if err != nil {
return err return err
} }
currentModel, err := dao.Model.Get(ctx, &entity.AsynchModel{ currentModel, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName}, SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
IsChatModel: gconv.PtrInt(1), IsChatModel: gconv.PtrInt(1),
}) })
@@ -161,7 +161,7 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM
return errors.New("当前模型为非推理模型,不能设置为会话模型") return errors.New("当前模型为非推理模型,不能设置为会话模型")
} }
if currentModel.Id != req.Id { if currentModel.Id != req.Id {
_, err = dao.Model.Update(ctx, &entity.AsynchModel{ _, err = dao.ModelGatewayModels.Update(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id}, SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id},
IsChatModel: gconv.PtrInt(0), IsChatModel: gconv.PtrInt(0),
}) })
@@ -171,7 +171,7 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM
} }
} }
_, err = dao.Model.Update(ctx, &entity.AsynchModel{ _, err = dao.ModelGatewayModels.Update(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.Id}, SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
IsChatModel: gconv.PtrInt(1), IsChatModel: gconv.PtrInt(1),
}) })
@@ -185,7 +185,7 @@ func (s *modelService) GetIsChatModel(ctx context.Context) (*dto.GetIsChatModelR
if err != nil { if err != nil {
return nil, err return nil, err
} }
model, err := dao.Model.Get(ctx, &entity.AsynchModel{ model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName}, SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
IsChatModel: gconv.PtrInt(1), IsChatModel: gconv.PtrInt(1),
}) })
@@ -203,14 +203,14 @@ func (s *modelService) clearUserChatModel(ctx context.Context) error {
if err != nil { if err != nil {
return err return err
} }
model, err := dao.Model.Get(ctx, &entity.AsynchModel{ model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName}, SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
IsChatModel: gconv.PtrInt(1), IsChatModel: gconv.PtrInt(1),
}) })
if err != nil || model == nil { if err != nil || model == nil {
return nil return nil
} }
_, err = dao.Model.Update(ctx, &entity.AsynchModel{ _, err = dao.ModelGatewayModels.Update(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Id: model.Id}, SQLBaseDO: beans.SQLBaseDO{Id: model.Id},
IsChatModel: gconv.PtrInt(0), IsChatModel: gconv.PtrInt(0),
}) })
@@ -223,7 +223,7 @@ func (s *modelService) checkChatModelUnique(ctx context.Context) error {
if err != nil { if err != nil {
return err return err
} }
model, err := dao.Model.Get(ctx, &entity.AsynchModel{ model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName}, SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
IsChatModel: gconv.PtrInt(1), IsChatModel: gconv.PtrInt(1),
}) })

View File

@@ -43,14 +43,14 @@ func AutoTune(ctx context.Context, req *dto.AutoTuneReq) (res *dto.AutoTuneRes,
req.WindowSeconds = 3600 // 默认1小时 req.WindowSeconds = 3600 // 默认1小时
} }
// 1) 读取模型配置cap按 model_name 聚合去重(如果表里有多租户重复数据,取较大上限) // 1) 读取模型配置cap按 model_name 聚合去重(如果表里有多租户重复数据,取较大上限)
var modelRows []*entity.AsynchModel var modelRows []*entity.ModelGatewayModel
if err := gfdb.DB(ctx).Model(ctx, public.TableNameModel). if err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
Where("deleted_at IS NULL"). Where("deleted_at IS NULL").
Where(entity.AsynchModelCol.Enabled, 1). Where(entity.ModelGatewayModelCol.Enabled, 1).
Scan(&modelRows); err != nil { Scan(&modelRows); err != nil {
return nil, err return nil, err
} }
modelMap := make(map[string]*entity.AsynchModel) modelMap := make(map[string]*entity.ModelGatewayModel)
for _, m := range modelRows { for _, m := range modelRows {
if m == nil || m.ModelName == "" { if m == nil || m.ModelName == "" {
continue continue

View File

@@ -2,36 +2,31 @@ package stat
import ( import (
"context" "context"
"model-gateway/model/entity"
"model-gateway/dao" "model-gateway/dao"
"model-gateway/model/dto" "model-gateway/model/dto"
) )
type statService struct{} var ModelGatewayLogsStat = &logsStatService{}
var Stat = &statService{} type logsStatService struct{}
func (s *statService) List(ctx context.Context, req *dto.ListModelStatReq) (res *dto.ListModelStatRes, err error) { func (s *logsStatService) List(ctx context.Context, req *dto.ListModelStatReq) (*dto.ListModelStatRes, error) {
pageNum, pageSize := 1, 10 if req == nil {
if req != nil { req = &dto.ListModelStatReq{}
if req.PageNum > 0 {
pageNum = req.PageNum
} }
if req.PageSize > 0 { if req.PageNum <= 0 {
pageSize = req.PageSize req.PageNum = 1
} }
if req.PageSize <= 0 {
req.PageSize = 10
} }
startDay, endDay := "", ""
var tenantID *int64 list, total, err := dao.ModelGatewayLogsStat.List(ctx, req.PageNum, req.PageSize, &entity.ModelGatewayLogsStat{
creator, modelName := "", "" Creator: req.Creator,
if req != nil { ModelName: req.ModelName,
startDay = req.StartDay })
endDay = req.EndDay
tenantID = req.TenantID
creator = req.Creator
modelName = req.ModelName
}
list, total, err := dao.Stat.List(ctx, pageNum, pageSize, startDay, endDay, tenantID, creator, modelName)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -17,25 +17,24 @@ import (
"gitea.redpowerfuture.com/red-future/common/utils" "gitea.redpowerfuture.com/red-future/common/utils"
"github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gtime"
"github.com/gogf/gf/v2/util/gconv" "github.com/gogf/gf/v2/util/gconv"
"github.com/google/uuid" "github.com/google/uuid"
) )
var Task = &taskService{} var ModelGatewayTask = &taskService{}
type taskService struct{} type taskService struct{}
// Create 创建任务 // Create 创建任务
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) { func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
startAt := time.Now()
taskID := uuid.NewString() taskID := uuid.NewString()
// 1) 检查模型配置,并且获取模型 // 1) 检查模型配置,并且获取模型
userInfo, err := utils.GetUserInfo(ctx) userInfo, err := utils.GetUserInfo(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
model, err := dao.Model.Get(ctx, &entity.AsynchModel{ model, err := dao.ModelGatewayModels.Get(ctx, &entity.ModelGatewayModel{
SQLBaseDO: beans.SQLBaseDO{ SQLBaseDO: beans.SQLBaseDO{
TenantId: userInfo.TenantId, TenantId: userInfo.TenantId,
Creator: userInfo.UserName, Creator: userInfo.UserName,
@@ -62,23 +61,17 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
} }
// 3) 插入任务记录 // 3) 插入任务记录
if model.CallMode != nil && *model.CallMode == public.CallModeAsync { requestPayload := entity.RequestPayload{
// 异步调用:注入回调地址后提交,拿到 task_id 轮询 Body: req.RequestPayload,
req.RequestPayload = util.InjectCallbackURL(ctx, req.RequestPayload, model.CallbackUrl) Headers: util.ParseHeadMsgHeaders(model.HeadMsg),
} }
storedPayload := map[string]any{ id, err := dao.ModelGatewayTask.Insert(ctx, &entity.ModelGatewayTask{
"headers": util.ParseHeadMsgHeaders(model.HeadMsg),
"body": req.RequestPayload,
}
_, err = dao.Task.Insert(ctx, &entity.AsynchTask{
ModelName: req.ModelName, ModelName: req.ModelName,
TaskID: taskID, TaskID: taskID,
State: 0, State: public.TaskStatusPending,
BizName: req.BizName, BizName: req.BizName,
CallbackURL: req.CallbackUrl, CallbackURL: req.CallbackUrl,
ModelKey: model.ApiKey, RequestPayload: &requestPayload,
InputRef: req.InputRef,
RequestPayload: storedPayload,
EpicycleId: req.EpicycleId, EpicycleId: req.EpicycleId,
}) })
if err != nil { // 入库失败:回滚闸门占位 if err != nil { // 入库失败:回滚闸门占位
@@ -97,7 +90,7 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
apiPath = r.URL.Path apiPath = r.URL.Path
httpMethod = r.Method httpMethod = r.Method
} }
_, _ = dao.OpLog.Insert(ctx, &entity.LogsModelOp{ _, _ = dao.ModelGatewayLogsOp.Insert(ctx, &entity.ModelGatewayLogsOp{
IP: ip, IP: ip,
UserAgent: ua, UserAgent: ua,
APIPath: apiPath, APIPath: apiPath,
@@ -107,22 +100,18 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
TaskID: taskID, TaskID: taskID,
OpType: "createTask", OpType: "createTask",
Success: 1, Success: 1,
ErrorMsg: "", CostMs: time.Since(time.Now()).Milliseconds(),
CostMs: time.Since(startAt).Milliseconds(), RequestPayload: &requestPayload,
RequestPayload: storedPayload,
ResponsePayload: gdb.Map{ ResponsePayload: gdb.Map{
"taskId": taskID, "taskId": taskID,
}, },
}) })
// 5) 获取任务信息 // 5) 获取任务信息
task, err := dao.Task.ClaimPendingByTaskIDGlobal(ctx, taskID) task, err := dao.ModelGatewayTask.ClaimByID(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if task == nil {
return nil, err
}
// 5) 创建成功后立即异步尝试执行当前任务 // 5) 创建成功后立即异步尝试执行当前任务
go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req) go AsyncWorker.handleOne(util.AsyncCtx(ctx), task, model, req)
@@ -130,10 +119,96 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
return &dto.CreateTaskRes{TaskID: taskID}, nil return &dto.CreateTaskRes{TaskID: taskID}, nil
} }
// GetResult 获取任务结果
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
t, err := dao.ModelGatewayTask.Get(ctx, &entity.ModelGatewayTask{
TaskID: taskID,
})
if err != nil {
return nil, err
}
if t == nil {
return nil, errors.New("任务不存在")
}
return &dto.GetTaskResultRes{
OssFile: t.ResultFile.OssFile,
State: t.State,
}, nil
}
// GetBatch 批量查询任务;将成功(state=2)的任务更新为已下载(state=4),并写入过期时间
func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
if req == nil || len(req.TaskIDs) == 0 {
return &dto.GetTaskBatchRes{List: []dto.GetTaskBatchItem{}}, nil
}
// 1) 先查当前租户下的任务列表
list, err := dao.ModelGatewayTask.ListByTaskIDs(ctx, req.TaskIDs)
if err != nil {
return nil, err
}
// 2) 对成功(state=2)的任务:标记为已下载(state=4)
for _, t := range list {
if t == nil {
continue
}
if t.State != public.BuildTypeNode {
continue
}
_ = dao.ModelGatewayTask.MarkDownloadedByID(ctx, t.Id)
// 为了本次返回一致性,内存里也更新
t.State = public.TaskStatusDownloaded
}
// 3) 组装返回
items := make([]dto.GetTaskBatchItem, 0, len(list))
for _, t := range list {
if t == nil {
continue
}
items = append(items, dto.GetTaskBatchItem{
TaskID: t.TaskID,
State: t.State,
OssFile: t.ResultFile.OssFile,
TextResult: t.TextResult,
})
}
return &dto.GetTaskBatchRes{List: items}, nil
}
// List 获取任务列表
func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (*dto.ListTaskRes, error) {
if req.PageNum <= 0 {
req.PageNum = 1
}
if req.PageSize <= 0 {
req.PageSize = 10
}
user, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
list, total, err := dao.ModelGatewayTask.List(ctx, req.PageNum, req.PageSize, &entity.ModelGatewayTask{
SQLBaseDO: beans.SQLBaseDO{
Creator: user.UserName,
},
ModelName: req.ModelName,
BizName: req.BizName,
State: req.State,
TaskID: req.TaskID,
})
if err != nil {
return nil, err
}
return &dto.ListTaskRes{List: list, Total: total}, nil
}
// ModelTaskCallback 模型异步任务的回调通知
func (s *taskService) ModelTaskCallback(ctx context.Context, req *dto.ModelTaskCallbackReq) (*dto.ModelTaskCallbackRes, error) { func (s *taskService) ModelTaskCallback(ctx context.Context, req *dto.ModelTaskCallbackReq) (*dto.ModelTaskCallbackRes, error) {
g.Log().Infof(ctx, "[模型回调] 收到通知 taskID=%s status=%s", req.TaskID, req.Status) g.Log().Infof(ctx, "[模型回调] 收到通知 taskID=%s status=%s", req.TaskID, req.Status)
// 1. 查本地任务 // 1. 查本地任务
task, err := dao.Task.Get(ctx, &entity.AsynchTask{ task, err := dao.ModelGatewayTask.Get(ctx, &entity.ModelGatewayTask{
TaskID: req.TaskID, TaskID: req.TaskID,
}) })
if err != nil || task == nil { if err != nil || task == nil {
@@ -167,7 +242,7 @@ func (s *taskService) QueryPendingTasks(ctx context.Context, req *dto.QueryPendi
} }
// 1. 查 state=1执行中的异步任务 // 1. 查 state=1执行中的异步任务
tasks, err := dao.Task.GetPendingAsyncTasks(ctx, limit) tasks, err := dao.ModelGatewayTask.GetPendingAsyncTasks(ctx, limit)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -176,7 +251,7 @@ func (s *taskService) QueryPendingTasks(ctx context.Context, req *dto.QueryPendi
var results []dto.QueryTaskItem var results []dto.QueryTaskItem
for _, t := range tasks { for _, t := range tasks {
// 拿到模型配置 // 拿到模型配置
model, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName) model, err := dao.ModelGatewayModels.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
if err != nil || model == nil || model.QueryConfig == nil { if err != nil || model == nil || model.QueryConfig == nil {
continue continue
} }
@@ -206,100 +281,3 @@ func (s *taskService) QueryPendingTasks(ctx context.Context, req *dto.QueryPendi
Results: results, Results: results,
}, nil }, nil
} }
// GetResult 获取任务结果
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
TaskID: taskID,
})
if err != nil {
return nil, err
}
if t == nil {
return nil, errors.New("任务不存在")
}
return &dto.GetTaskResultRes{
OssFile: t.OssFile,
State: t.State,
}, nil
}
// GetBatch 批量查询任务;将成功(state=2)的任务更新为已下载(state=4),并写入过期时间
func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
if req == nil || len(req.TaskIDs) == 0 {
return &dto.GetTaskBatchRes{List: []dto.GetTaskBatchItem{}}, nil
}
// 1) 先查当前租户下的任务列表
list, err := dao.Task.ListByTaskIDs(ctx, req.TaskIDs)
if err != nil {
return nil, err
}
// 2) 对成功(state=2)的任务:标记为已下载(state=4)并写入 expire_at
now := time.Now()
for _, t := range list {
if t == nil {
continue
}
if t.State != 2 {
continue
}
// 按模型配置决定保留时间
m, err := dao.Model.Get(ctx, &entity.AsynchModel{
ModelName: t.ModelName,
})
if err != nil {
return nil, err
}
retainSeconds := 86400
if m != nil && m.AutoCleanSeconds > 0 {
retainSeconds = m.AutoCleanSeconds
}
expireAt := gtime.New(now.Add(time.Duration(retainSeconds) * time.Second))
_ = dao.Task.MarkDownloadedByID(ctx, t.Id, expireAt)
// 为了本次返回一致性,内存里也更新
t.State = 4
t.ExpireAt = expireAt
}
// 3) 组装返回
items := make([]dto.GetTaskBatchItem, 0, len(list))
for _, t := range list {
if t == nil {
continue
}
items = append(items, dto.GetTaskBatchItem{
TaskID: t.TaskID,
State: t.State,
OssFile: t.OssFile,
})
}
return &dto.GetTaskBatchRes{List: items}, nil
}
// List 获取任务列表
func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
pageNum, pageSize := 1, 10
if req != nil {
if req.PageNum > 0 {
pageNum = req.PageNum
}
if req.PageSize > 0 {
pageSize = req.PageSize
}
}
modelName := ""
taskID := ""
var state *int
if req != nil {
modelName = req.ModelName
taskID = req.TaskID
state = req.State
}
list, total, err := dao.Task.List(ctx, pageNum, pageSize, modelName, taskID, state)
if err != nil {
return nil, err
}
return &dto.ListTaskRes{List: list, Total: total}, nil
}

View File

@@ -21,6 +21,7 @@ import (
"model-gateway/service/gateway" "model-gateway/service/gateway"
"model-gateway/service/queue" "model-gateway/service/queue"
"gitea.redpowerfuture.com/red-future/common/beans"
"github.com/gogf/gf/v2/encoding/gjson" "github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv" "github.com/gogf/gf/v2/util/gconv"
@@ -32,14 +33,19 @@ type asyncWorker struct {
} }
// handleOne 执行一次完整的任务 // handleOne 执行一次完整的任务
func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, req *dto.CreateTaskReq) { func (w *asyncWorker) handleOne(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq) {
body := util.GetModelBody(task.RequestPayload) // 核心请求参数 var (
maxRetry := model.RetryTimes // 重试次数 body = task.RequestPayload.Body
startTime := time.Now() maxRetry = model.RetryTimes
startTime = time.Now()
result map[string]any
err error
)
g.Log().Infof(ctx, "[执行任务][开始] taskId=%s model=%s", task.TaskID, task.ModelName) g.Log().Infof(ctx, "[执行任务][开始] taskId=%s model=%s", task.TaskID, task.ModelName)
// ============================================
// 1) 分布式并发控制 // 1) 分布式并发控制
// ============================================
semKey := fmt.Sprintf("asynch:sem:%s", task.ModelName) semKey := fmt.Sprintf("asynch:sem:%s", task.ModelName)
maxC := queue.GetRuntimeMaxConcurrency(ctx, task.ModelName, model.MaxConcurrency) maxC := queue.GetRuntimeMaxConcurrency(ctx, task.ModelName, model.MaxConcurrency)
acquired, err := queue.AcquireSemaphore(ctx, semKey, maxC, 3600) acquired, err := queue.AcquireSemaphore(ctx, semKey, maxC, 3600)
@@ -49,88 +55,93 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo
return return
} }
if !acquired { if !acquired {
_, _ = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
SQLBaseDO: beans.SQLBaseDO{Id: task.Id},
State: public.TaskStatusPending,
})
g.Log().Infof(ctx, "[执行任务][排队] 并发已满,放回队列 taskId=%s", task.TaskID) g.Log().Infof(ctx, "[执行任务][排队] 并发已满,放回队列 taskId=%s", task.TaskID)
_ = w.rollbackToPending(ctx, task.Id)
return return
} }
defer func() { _ = queue.ReleaseSemaphore(ctx, semKey) }() defer func() { _ = queue.ReleaseSemaphore(ctx, semKey) }()
// ============================================
// 2) 调用模型 // 2) 调用模型
// ============================================
switch { switch {
case model.CallMode != nil && *model.CallMode == public.CallModeStream: case model.CallMode != nil && *model.CallMode == public.CallModeStream:
rawBytes, err := w.callModelStream(ctx, task, model, body) rawBytes, streamErr := w.callModelStream(ctx, task, model, body)
if err != nil { if streamErr != nil {
w.failTask(ctx, task, startTime, err.Error()) w.failTask(ctx, task, startTime, streamErr.Error())
return
}
body, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
if err != nil {
w.failTask(ctx, task, startTime, err.Error())
return return
} }
result, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
case model.CallMode != nil && *model.CallMode == public.CallModeAsync: case model.CallMode != nil && *model.CallMode == public.CallModeAsync:
body, err = w.callModel(ctx, task, model, body) result, err = w.callModel(ctx, task, model, body)
if err != nil { if err == nil {
w.failTask(ctx, task, startTime, err.Error()) result, err = util.PullTaskResult(ctx, result, model.QueryConfig, model.HeadMsg)
return
}
body, err = util.PullTaskResult(ctx, body, model.QueryConfig, model.HeadMsg)
if err != nil {
w.failTask(ctx, task, startTime, err.Error())
return
} }
default: default:
body, err = w.callModel(ctx, task, model, body) result, err = w.callModel(ctx, task, model, body)
}
if err != nil { if err != nil {
w.failTask(ctx, task, startTime, err.Error()) w.failTask(ctx, task, startTime, err.Error())
return return
} }
}
// 3) 保存临时文件 // ============================================
tmpPath, err := util.SaveTempFileByType(task.TaskID, body, task.TmpFile) // 3) 缓存临时文件
if err == nil && tmpPath != "" { // ============================================
if tmpPath, tmpErr := util.SaveTempFileByType(task.TaskID, result, task.TmpFile); tmpErr == nil && tmpPath != "" {
task.TmpFile = tmpPath task.TmpFile = tmpPath
task.Phase = 1 task.Phase = 1
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath) _, _ = dao.ModelGatewayTask.Update(ctx, task)
} }
// 4) 解析校验 + 响应映射(可重试,失败重新调模型) // ============================================
body, err = w.parseAndRetry(ctx, body, task, model, req, maxRetry, startTime) // 4) 解析校验 + 响应映射(可重试)
// ============================================
result, err = w.parseAndRetry(ctx, result, task, model, req, maxRetry, startTime)
if err != nil { if err != nil {
task.TextResult = body task.TextResult = result
w.failTask(ctx, task, startTime, err.Error()) w.failTask(ctx, task, startTime, err.Error())
return return
} }
// ============================================
// 5) 上传 OSS可重试 // 5) 上传 OSS可重试
// ============================================
var oss *gateway.UploadFileResponse var oss *gateway.UploadFileResponse
for attempt := 0; attempt <= maxRetry; attempt++ { for attempt := 0; attempt <= maxRetry; attempt++ {
if attempt > 0 { if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] OSS上传 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID) g.Log().Infof(ctx, "[执行任务][重试] OSS上传 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
} }
oss, err = w.uploadOSS(ctx, task) oss, err = gateway.UploadByTask(ctx, gjson.New(result).MustToJson(), "json")
if err == nil { if err == nil {
break break
} }
g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v", g.Log().Errorf(ctx, "[执行任务][失败] OSS上传失败 taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, err)
task.TaskID, attempt, maxRetry, err)
if attempt == maxRetry { if attempt == maxRetry {
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, task.Id, err.Error()) task.State = public.TaskStatusFailed
task.ErrorMsg = err.Error()
task.Phase = 1
_, _ = dao.ModelGatewayTask.Update(ctx, task)
w.failTask(ctx, task, startTime, fmt.Sprintf("OSS上传重试耗尽: %v", err)) w.failTask(ctx, task, startTime, fmt.Sprintf("OSS上传重试耗尽: %v", err))
return return
} }
} }
// 6) 成功回调 // ============================================
task.State = 2 // 6) 成功收尾
// ============================================
task.State = public.TaskStatusSuccess
task.DurationSeconds = int64(time.Since(startTime).Seconds()) task.DurationSeconds = int64(time.Since(startTime).Seconds())
task.OssFile = oss.FileAddressPrefix + oss.FileURL task.ResultFile = &entity.ResultFile{
task.FileType = oss.FileFormat OssFile: oss.FileAddressPrefix + oss.FileURL,
task.TextResult = body FileType: oss.FileFormat,
task.FileSize = int64(oss.FileSize) FileSize: int64(oss.FileSize),
}
if err = dao.Task.UpdateSuccessGlobal(ctx, task); err != nil { task.TextResult = result
if _, err = dao.ModelGatewayTask.Update(ctx, task); err != nil {
g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err) g.Log().Errorf(ctx, "[执行任务][失败] 更新数据库失败 taskId=%s err=%v", task.TaskID, err)
return return
} }
@@ -141,15 +152,14 @@ func (w *asyncWorker) handleOne(ctx context.Context, task *entity.AsynchTask, mo
go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), task, req.EpicycleId) go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), task, req.EpicycleId)
} }
g.Log().Infof(ctx, "[执行任务][成功] taskId=%s duration=%ds fileType=%s textLen=%d callbackUrl=%s", g.Log().Infof(ctx, "[执行任务][成功] taskId=%s duration=%ds fileType=%s",
task.TaskID, task.DurationSeconds, oss.FileFormat, len(body), task.CallbackURL) task.TaskID, task.DurationSeconds, oss.FileFormat)
// 7) 删除临时文件
_ = os.Remove(task.TmpFile) _ = os.Remove(task.TmpFile)
} }
// callModelStream 调用模型,返回原始字节(不做响应映射,用于流式输出) // callModelStream 调用模型,返回原始字节(不做响应映射,用于流式输出)
func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) ([]byte, error) { func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) {
var data []byte var data []byte
var err error var err error
@@ -161,8 +171,7 @@ func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.AsynchTa
} }
if data == nil { if data == nil {
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName) data, err = InvokeModel(ctx, model, body)
data, err = InvokeModel(ctx, model, body, task.ModelKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -170,7 +179,10 @@ func (w *asyncWorker) callModelStream(ctx context.Context, task *entity.AsynchTa
if tmpErr == nil && tmpPath != "" { if tmpErr == nil && tmpPath != "" {
task.TmpFile = tmpPath task.TmpFile = tmpPath
task.Phase = 1 task.Phase = 1
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath) _, err = dao.ModelGatewayTask.Update(ctx, task)
if err != nil {
g.Log().Errorf(ctx, "[执行任务][失败] 临时文件保存失败 taskId=%s err=%v", task.TaskID, tmpErr)
}
} }
} }
@@ -186,14 +198,14 @@ type asyncResult struct {
// asyncTaskChan 全局异步任务等待通道 // asyncTaskChan 全局异步任务等待通道
var asyncTaskChan = sync.Map{} // taskID → chan asyncResult var asyncTaskChan = sync.Map{} // taskID → chan asyncResult
func (w *asyncWorker) callModelAsync(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) (map[string]any, error) { func (w *asyncWorker) callModelAsync(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, body map[string]any) (map[string]any, error) {
// 1. 提交异步任务 // 1. 提交异步任务
body, err := w.callModel(ctx, task, model, body) body, err := w.callModel(ctx, task, model, body)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 2. 拿到 task_id // 2. 拿到 task_id
taskID := gjson.New(body).Get(model.ResponseBody).String() taskID := gjson.New(body).Get(entity.ResponseBody).String()
// 3. 创建等待通道 // 3. 创建等待通道
ch := make(chan asyncResult, 1) ch := make(chan asyncResult, 1)
@@ -231,7 +243,7 @@ func NotifyAsyncResult(taskID string, result map[string]any, err error) {
// callModel 调用模型 + 检测文件类型 + 保存临时文件 // callModel 调用模型 + 检测文件类型 + 保存临时文件
// 返回: 解析后的响应体, error // 返回: 解析后的响应体, error
func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, body map[string]any) (map[string]any, error) { func (w *asyncWorker) callModel(ctx context.Context, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, body map[string]any) (map[string]any, error) {
var data []byte var data []byte
var err error var err error
@@ -246,8 +258,7 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, mo
// 2) 没有可用数据,调用模型 // 2) 没有可用数据,调用模型
if data == nil { if data == nil {
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName) data, err = InvokeModel(ctx, model, body)
data, err = InvokeModel(ctx, model, body, task.ModelKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -258,7 +269,10 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, mo
if tmpErr == nil && tmpPath != "" { if tmpErr == nil && tmpPath != "" {
task.TmpFile = tmpPath task.TmpFile = tmpPath
task.Phase = 1 task.Phase = 1
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath) _, err = dao.ModelGatewayTask.Update(ctx, task)
if err != nil {
g.Log().Errorf(ctx, "[执行任务][失败] 临时文件保存失败 taskId=%s err=%v", task.TaskID, tmpErr)
}
} }
} }
@@ -279,7 +293,7 @@ func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, mo
} }
// parseAndRetry 解析模型返回结果,并重试 // parseAndRetry 解析模型返回结果,并重试
func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.AsynchTask, model *entity.AsynchModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) { func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, task *entity.ModelGatewayTask, model *entity.ModelGatewayModel, req *dto.CreateTaskReq, maxRetry int, startTime time.Time) (map[string]any, error) {
for attempt := 0; attempt <= maxRetry; attempt++ { for attempt := 0; attempt <= maxRetry; attempt++ {
if attempt > 0 { if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID) g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, task.TaskID)
@@ -296,10 +310,11 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
} }
// 2) 先存 token 到数据库,防止后续失败丢失 // 2) 先存 token 到数据库,防止后续失败丢失
if tokens, ok := mapped[model.ResponseTokenField]; ok { if _, ok := mapped[entity.TotalTokens]; ok {
task.ExpendTokens = gconv.Int64(tokens) task.ExpendTokens = gconv.Int64(mapped[entity.TotalTokens])
_ = dao.Task.UpdateColumns(ctx, task.Id, entity.AsynchTask{ _, err = dao.ModelGatewayTask.Update(ctx, &entity.ModelGatewayTask{
ExpendTokens: gconv.Int64(body[model.ResponseTokenField]), SQLBaseDO: beans.SQLBaseDO{Id: task.Id},
ExpendTokens: task.ExpendTokens,
}) })
} }
@@ -312,7 +327,7 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
return parsed, nil return parsed, nil
} }
case public.BuildTypeStruct: case public.BuildTypeStruct:
parsed = util.ParseStructResult(mapped, model.ResponseBody) parsed = util.ParseStructResult(mapped, entity.ResponseBody)
return parsed, nil return parsed, nil
default: default:
return mapped, nil return mapped, nil
@@ -325,9 +340,10 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
} }
// 4) 重新调模型(直接调,不走缓存) // 4) 重新调模型(直接调,不走缓存)
_ = dao.Task.IncRetryCountGlobal(ctx, task.Id) task.RetryCount++
reqBody := util.GetModelBody(task.RequestPayload) _, _ = dao.ModelGatewayTask.Update(ctx, task)
rawData, callErr := InvokeModel(ctx, model, reqBody, task.ModelKey) rawData, callErr := InvokeModel(ctx, model, task.RequestPayload.Body)
if callErr != nil { if callErr != nil {
g.Log().Warningf(ctx, "[执行任务][重调模型失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, callErr) g.Log().Warningf(ctx, "[执行任务][重调模型失败] taskId=%s attempt=%d/%d err=%v", task.TaskID, attempt, maxRetry, callErr)
continue continue
@@ -335,7 +351,7 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
// 5) 解析原始响应,覆盖 body 进入下一轮 // 5) 解析原始响应,覆盖 body 进入下一轮
var rawResp map[string]any var rawResp map[string]any
if err := json.Unmarshal(rawData, &rawResp); err != nil { if err = json.Unmarshal(rawData, &rawResp); err != nil {
g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err) g.Log().Warningf(ctx, "[执行任务][Unmarshal失败] taskId=%s err=%v", task.TaskID, err)
continue continue
} }
@@ -347,18 +363,21 @@ func (w *asyncWorker) parseAndRetry(ctx context.Context, body map[string]any, ta
// InvokeModel 调用模型服务,返回二进制结果 // InvokeModel 调用模型服务,返回二进制结果
// modelKey 用于覆盖/补充模型配置 head_msg例如每次请求携带不同的 X-API-Key // modelKey 用于覆盖/补充模型配置 head_msg例如每次请求携带不同的 X-API-Key
func InvokeModel(ctx context.Context, model *entity.AsynchModel, body map[string]any, modelKey string) ([]byte, error) { func InvokeModel(ctx context.Context, model *entity.ModelGatewayModel, body map[string]any) ([]byte, error) {
// 1)请求参数映射:将标准 payload 按模型配置的 requestMapping 转为模型需要的格式 // 1) 记录模型调用次数
_ = dao.ModelGatewayLogsStat.IncRequestCount(ctx, time.Now(), model.TenantId, model.Creator, model.ModelName)
// 2请求参数映射将标准 payload 按模型配置的 requestMapping 转为模型需要的格式
//—— 请求映射实际处理为提示词构建请求,因为有附加字段及其他字段的拼接。这里不方便做请求映射 //—— 请求映射实际处理为提示词构建请求,因为有附加字段及其他字段的拼接。这里不方便做请求映射
//mappedPayload := util.ReverseMap(model.RequestMapping, payload) //mappedPayload := util.ReverseMap(model.RequestMapping, payload)
// 2)构建请求 URL 和超时 // 3)构建请求 URL 和超时
baseURL := strings.TrimRight(model.BaseURL, "/") baseURL := strings.TrimRight(model.BaseURL, "/")
timeout := time.Duration(model.TimeoutSeconds) * time.Second timeout := time.Duration(model.TimeoutSeconds) * time.Second
client := &http.Client{Timeout: timeout} client := &http.Client{Timeout: timeout}
method := strings.ToUpper(strings.TrimSpace(model.HttpMethod)) method := strings.ToUpper(strings.TrimSpace(model.HttpMethod))
// 3)构建 HTTP 请求 // 4)构建 HTTP 请求
var req *http.Request var req *http.Request
switch method { switch method {
case http.MethodGet: case http.MethodGet:
@@ -382,31 +401,31 @@ func InvokeModel(ctx context.Context, model *entity.AsynchModel, body map[string
req, err = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes)) req, err = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes))
} }
// 4)注入请求头:先模型静态配置,再动态 modelKey后者可覆盖前者 // 5)注入请求头:先模型静态配置,再动态 modelKey后者可覆盖前者
for hk, hv := range util.ParseHeadMsgHeaders(model.HeadMsg) { for hk, hv := range util.ParseHeadMsgHeaders(model.HeadMsg) {
req.Header.Set(hk, hv) req.Header.Set(hk, hv)
} }
if modelKey != "" { if model.ApiKey != "" {
req.Header.Set("Authorization", "Bearer "+modelKey) req.Header.Set("Authorization", "Bearer "+model.ApiKey)
} }
if method != http.MethodGet { if method != http.MethodGet {
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
} }
// 5)发送请求 // 6)发送请求
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer resp.Body.Close() defer resp.Body.Close()
// 6)读取响应体 // 7)读取响应体
b, err := io.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 7)检查 HTTP 状态码 // 8)检查 HTTP 状态码
if resp.StatusCode < 200 || resp.StatusCode >= 300 { if resp.StatusCode < 200 || resp.StatusCode >= 300 {
msg := string(b) msg := string(b)
return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg) return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
@@ -468,27 +487,15 @@ func InvokeModel(ctx context.Context, model *entity.AsynchModel, body map[string
// return mappedResponse, nil // return mappedResponse, nil
// } // }
// uploadOSS 从临时文件上传 OSS
func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.AsynchTask) (*gateway.UploadFileResponse, error) {
data, err := os.ReadFile(t.TmpFile)
if err != nil {
return nil, fmt.Errorf("读取临时文件失败: %w", err)
}
_, ext := util.DetectFileType(data)
return gateway.UploadByTask(ctx, data, ext)
}
// failTask 任务失败统一处理:更新数据库 + 释放排队 + 回调 // failTask 任务失败统一处理:更新数据库 + 释放排队 + 回调
func (w *asyncWorker) failTask(ctx context.Context, t *entity.AsynchTask, startTime time.Time, errMsg string) { func (w *asyncWorker) failTask(ctx context.Context, t *entity.ModelGatewayTask, startTime time.Time, errMsg string) {
t.State = 3 t.State = 3
t.ErrorMsg = errMsg t.ErrorMsg = errMsg
t.DurationSeconds = int64(time.Since(startTime).Seconds()) t.DurationSeconds = int64(time.Since(startTime).Seconds())
_ = dao.Task.UpdateFailedGlobal(ctx, t) _, err := dao.ModelGatewayTask.Update(ctx, t)
if err != nil {
g.Log().Warningf(ctx, "[执行任务][更新数据库失败] taskId=%s err=%v", t.TaskID, err)
}
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
go gateway.TriggerCallback(context.WithoutCancel(ctx), t) go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
} }
// rollbackToPending 恢复任务状态为 PENDING
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
return dao.Task.RollbackToPendingGlobal(ctx, id)
}

View File

@@ -1,285 +1,230 @@
-- model-asynch 核心表(pgsql)
-- 1) asynch_models模型配置
-- 2) asynch_task异步任务
-- 3) logs_model_op操作日志(统计用)
-- 4) logs_model_stat按天模型请求统计(限流/监控用)
-- ========================= -- =========================
-- 1) asynch_models -- model_gateway_models
-- ========================= -- =========================
CREATE TABLE IF NOT EXISTS asynch_models ( CREATE TABLE IF NOT EXISTS model_gateway_models (
-- ========== 基础字段 ========== id int8 PRIMARY KEY,
id BIGINT PRIMARY KEY, tenant_id int8 NOT NULL DEFAULT 0,
tenant_id BIGINT NOT NULL DEFAULT 0, creator varchar(64) NOT NULL,
creator VARCHAR(64) NOT NULL, created_at timestamp(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updater varchar(64) NOT NULL,
updater VARCHAR(64) NOT NULL, updated_at timestamp(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, deleted_at timestamp(6),
deleted_at TIMESTAMP(6), model_name varchar(128) NOT NULL,
model_type int2 NOT NULL DEFAULT 0,
-- ========== 模型标识 ========== operator_name varchar(64) NOT NULL DEFAULT '',
model_name VARCHAR(128) NOT NULL, base_url varchar(256) NOT NULL,
model_type SMALLINT NOT NULL DEFAULT 0, http_method varchar(8) NOT NULL DEFAULT 'POST',
operator_name VARCHAR(64) NOT NULL DEFAULT '', head_msg jsonb NOT NULL DEFAULT '{}',
api_key varchar(256) NOT NULL DEFAULT '',
-- ========== 请求配置 ========== is_private int2 NOT NULL DEFAULT 0,
base_url VARCHAR(256) NOT NULL, enabled int2 NOT NULL DEFAULT 1,
http_method VARCHAR(8) NOT NULL DEFAULT 'POST', is_chat_model int2 NOT NULL DEFAULT 0,
head_msg JSONB NOT NULL DEFAULT '{}'::jsonb, is_owner int2 NOT NULL DEFAULT 99,
api_key VARCHAR(256) NOT NULL DEFAULT '', form_json jsonb NOT NULL DEFAULT '{}',
request_mapping jsonb NOT NULL DEFAULT '{}',
-- ========== 状态开关 ========== response_mapping jsonb NOT NULL DEFAULT '{}',
is_private SMALLINT NOT NULL DEFAULT 0, response_body varchar(128) NOT NULL DEFAULT '',
enabled SMALLINT NOT NULL DEFAULT 1, token_config jsonb NOT NULL DEFAULT '{}',
is_chat_model SMALLINT NOT NULL DEFAULT 0, extend_mapping jsonb NOT NULL DEFAULT '{}',
is_async SMALLINT NOT NULL DEFAULT 0, query_config jsonb NOT NULL DEFAULT '{}',
is_stream SMALLINT NOT NULL DEFAULT 0, stream_config jsonb NOT NULL DEFAULT '{}',
is_owner SMALLINT NOT NULL DEFAULT 99, first_frame varchar(128) NOT NULL DEFAULT '',
last_frame varchar(128) NOT NULL DEFAULT '',
-- ========== 配置相关 ========== max_concurrency int4 NOT NULL DEFAULT 10,
form_json JSONB NOT NULL DEFAULT '{}'::jsonb, timeout_seconds int4 NOT NULL DEFAULT 600,
request_mapping JSONB NOT NULL DEFAULT '{}'::jsonb, retry_times int2 NOT NULL DEFAULT 3,
response_mapping JSONB NOT NULL DEFAULT '{}'::jsonb, response_token_field varchar(128) NOT NULL DEFAULT '',
response_body JSONB NOT NULL DEFAULT '{}'::jsonb, call_mode int2 NOT NULL DEFAULT 0,
token_config JSONB NOT NULL DEFAULT '{}'::jsonb, required_fields jsonb NOT NULL DEFAULT '[]',
extend_mapping JSONB NOT NULL DEFAULT '{}'::jsonb, max_tokens int4 DEFAULT 0
query_config JSONB NOT NULL DEFAULT '{}'::jsonb,
stream_config JSONB NOT NULL DEFAULT '{}'::jsonb,
first_frame VARCHAR(128) NOT NULL DEFAULT '',
last_frame VARCHAR(128) NOT NULL DEFAULT '',
-- ========== 限制与重试 ==========
max_concurrency INT NOT NULL DEFAULT 10,
timeout_seconds INT NOT NULL DEFAULT 600,
retry_times SMALLINT NOT NULL DEFAULT 3,
auto_clean_seconds INT NOT NULL DEFAULT 86400,
-- ========== 其他 ==========
response_token_field VARCHAR(128) NOT NULL DEFAULT '',
); );
-- ========== 索引 ========== CREATE UNIQUE INDEX IF NOT EXISTS uk_model_gateway_models_tenant_creator_model ON model_gateway_models (tenant_id, creator, model_name);
CREATE UNIQUE INDEX IF NOT EXISTS uk_asynch_models_tenant_creator_chat ON asynch_models(tenant_id, creator) WHERE is_chat_model = 1 AND deleted_at IS NULL; CREATE INDEX IF NOT EXISTS idx_model_gateway_models_model_name ON model_gateway_models (model_name);
CREATE UNIQUE INDEX IF NOT EXISTS uk_asynch_models_tenant_model_name ON asynch_models(tenant_id, creator, model_name); CREATE INDEX IF NOT EXISTS idx_model_gateway_models_model_type ON model_gateway_models (model_type);
CREATE INDEX IF NOT EXISTS idx_asynch_models_tenant_id ON asynch_models(tenant_id); CREATE INDEX IF NOT EXISTS idx_model_gateway_models_tenant_id ON model_gateway_models (tenant_id);
CREATE INDEX IF NOT EXISTS idx_asynch_models_model_name ON asynch_models(model_name); CREATE INDEX IF NOT EXISTS idx_model_gateway_models_deleted_at ON model_gateway_models (deleted_at);
CREATE INDEX IF NOT EXISTS idx_asynch_models_model_type ON asynch_models(model_type); CREATE INDEX IF NOT EXISTS idx_model_gateway_models_enabled ON model_gateway_models (enabled);
CREATE INDEX IF NOT EXISTS idx_asynch_models_enabled ON asynch_models(enabled);
CREATE INDEX IF NOT EXISTS idx_asynch_models_deleted_at ON asynch_models(deleted_at);
-- ========== 注释 ========== COMMENT ON TABLE model_gateway_models IS '模型配置表';
COMMENT ON TABLE asynch_models IS '模型配置表'; COMMENT ON COLUMN model_gateway_models.id IS '主键ID(非自增)';
COMMENT ON COLUMN model_gateway_models.tenant_id IS '租户ID';
COMMENT ON COLUMN model_gateway_models.creator IS '创建人';
COMMENT ON COLUMN model_gateway_models.created_at IS '创建时间';
COMMENT ON COLUMN model_gateway_models.updater IS '更新人';
COMMENT ON COLUMN model_gateway_models.updated_at IS '更新时间';
COMMENT ON COLUMN model_gateway_models.deleted_at IS '删除时间(软删)';
COMMENT ON COLUMN asynch_models.id IS '主键ID(非自增)'; COMMENT ON COLUMN model_gateway_models.model_name IS '模型名称';
COMMENT ON COLUMN asynch_models.tenant_id IS '租户ID'; COMMENT ON COLUMN model_gateway_models.model_type IS '模型类型';
COMMENT ON COLUMN asynch_models.creator IS '创建人'; COMMENT ON COLUMN model_gateway_models.operator_name IS '运营商名称';
COMMENT ON COLUMN asynch_models.created_at IS '创建时间'; COMMENT ON COLUMN model_gateway_models.base_url IS '模型地址';
COMMENT ON COLUMN asynch_models.updater IS '更新人'; COMMENT ON COLUMN model_gateway_models.http_method IS '请求方式 GET/POST';
COMMENT ON COLUMN asynch_models.updated_at IS '更新时间'; COMMENT ON COLUMN model_gateway_models.head_msg IS '请求头信息';
COMMENT ON COLUMN asynch_models.deleted_at IS '删除时间(软删)'; COMMENT ON COLUMN model_gateway_models.api_key IS '调用凭证/密钥';
COMMENT ON COLUMN model_gateway_models.is_private IS '是否私有化0-私有 1-公共';
COMMENT ON COLUMN asynch_models.model_name IS '模型名称'; COMMENT ON COLUMN model_gateway_models.enabled IS '是否启用0-停用 1-启用';
COMMENT ON COLUMN asynch_models.model_type IS '模型类型'; COMMENT ON COLUMN model_gateway_models.is_chat_model IS '是否为对话模型0-否 1-是';
COMMENT ON COLUMN asynch_models.operator_name IS '运营商名称'; COMMENT ON COLUMN model_gateway_models.is_owner IS '1=当前用户创建 0=超级管理员';
COMMENT ON COLUMN asynch_models.base_url IS '模型地址'; COMMENT ON COLUMN model_gateway_models.call_mode IS '调用模式0-同步 1-异步 2-流式';
COMMENT ON COLUMN asynch_models.http_method IS '请求方式 GET/POST'; COMMENT ON COLUMN model_gateway_models.form_json IS '动态表单结构';
COMMENT ON COLUMN asynch_models.head_msg IS '请求头信息'; COMMENT ON COLUMN model_gateway_models.request_mapping IS '请求映射';
COMMENT ON COLUMN asynch_models.api_key IS '调用凭证/密钥'; COMMENT ON COLUMN model_gateway_models.response_mapping IS '返回映射';
COMMENT ON COLUMN asynch_models.is_private IS '是否私有化0-私有 1-公共'; COMMENT ON COLUMN model_gateway_models.response_body IS '返回主体';
COMMENT ON COLUMN asynch_models.enabled IS '是否启用0-停用 1-启用'; COMMENT ON COLUMN model_gateway_models.token_config IS 'Token计算配置';
COMMENT ON COLUMN asynch_models.is_chat_model IS '是否为对话模型0-否 1-是'; COMMENT ON COLUMN model_gateway_models.extend_mapping IS '附加映射';
COMMENT ON COLUMN asynch_models.is_async IS '是否异步0-同步 1-异步'; COMMENT ON COLUMN model_gateway_models.query_config IS '查询/回调配置';
COMMENT ON COLUMN asynch_models.is_stream IS '是否流式0-非流式 1-流式'; COMMENT ON COLUMN model_gateway_models.stream_config IS '流式输出配置';
COMMENT ON COLUMN asynch_models.is_owner IS '1=当前用户创建 0=超级管理员'; COMMENT ON COLUMN model_gateway_models.first_frame IS '首帧图片参数';
COMMENT ON COLUMN asynch_models.form_json IS '动态表单结构'; COMMENT ON COLUMN model_gateway_models.last_frame IS '尾帧图片参数';
COMMENT ON COLUMN asynch_models.request_mapping IS '请求映射'; COMMENT ON COLUMN model_gateway_models.max_concurrency IS '最大并发数';
COMMENT ON COLUMN asynch_models.response_mapping IS '返回映射'; COMMENT ON COLUMN model_gateway_models.timeout_seconds IS '调用模型超时(秒)';
COMMENT ON COLUMN asynch_models.response_body IS '返回主体'; COMMENT ON COLUMN model_gateway_models.retry_times IS '失败重试次数';
COMMENT ON COLUMN asynch_models.token_config IS 'Token计算配置'; COMMENT ON COLUMN model_gateway_models.response_token_field IS '响应中消耗token的字段映射';
COMMENT ON COLUMN asynch_models.extend_mapping IS '附加映射'; COMMENT ON COLUMN model_gateway_models.required_fields IS '必选字段列表';
COMMENT ON COLUMN asynch_models.query_config IS '查询/回调配置'; COMMENT ON COLUMN model_gateway_models.max_tokens IS '最大 token 数0 表示不传';
COMMENT ON COLUMN asynch_models.stream_config IS '流式输出配置';
COMMENT ON COLUMN asynch_models.first_frame IS '首帧图片参数';
COMMENT ON COLUMN asynch_models.last_frame IS '尾帧图片参数';
COMMENT ON COLUMN asynch_models.max_concurrency IS '最大并发数';
COMMENT ON COLUMN asynch_models.timeout_seconds IS '调用模型超时(秒)';
COMMENT ON COLUMN asynch_models.retry_times IS '失败重试次数';
COMMENT ON COLUMN asynch_models.auto_clean_seconds IS '任务完成后自动清理时间(秒)';
COMMENT ON COLUMN asynch_models.response_token_field IS '响应中消耗token的字段映射';
-- =========================
-- 2) asynch_task
-- =========================
CREATE TABLE IF NOT EXISTS asynch_task (
-- 基础字段
id BIGINT PRIMARY KEY, -- 主键ID(非自增)
tenant_id BIGINT NOT NULL DEFAULT 0, -- 租户ID
creator VARCHAR(64) NOT NULL, -- 创建人
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- 创建时间
updater VARCHAR(64) NOT NULL, -- 更新人
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- 更新时间
deleted_at TIMESTAMP(6), -- 删除时间(软删)
-- 业务字段
model_name VARCHAR(128) NOT NULL, -- 模型名称
task_id VARCHAR(64) NOT NULL, -- 任务ID(对外返回)
biz_name VARCHAR(128) NOT NULL DEFAULT '', -- 业务名称(调用方模块/系统)
callback_url VARCHAR(512) DEFAULT '', -- 回调地址(可选,用于后续业务通知)
model_key VARCHAR(1024) DEFAULT '', -- 动态请求头(用于覆盖/补充模型配置 head_msg),如 X-API-Key:xxx
state SMALLINT NOT NULL DEFAULT 0, -- 0排队中/1执行中/2成功/3失败/4已下载
oss_file VARCHAR(512) DEFAULT '', -- 结果文件OSS地址
file_type VARCHAR(32) DEFAULT '', -- 文件类型(mp3/mp4/png/...)
file_size BIGINT NOT NULL DEFAULT 0, -- 文件大小(字节)
error_msg TEXT DEFAULT '', -- 错误信息
started_at TIMESTAMP, -- 开始执行时间
finished_at TIMESTAMP, -- 执行结束时间
duration_seconds BIGINT NOT NULL DEFAULT 0, -- 耗时(秒):从创建到完成(成功/失败)整体耗时
expire_at TIMESTAMP, -- state=4 后写入,用于清理
retry_count INT NOT NULL DEFAULT 0, -- 已重试次数(不含首次)
enqueue_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- 入队时间(用于排队顺序)
phase SMALLINT NOT NULL DEFAULT 0, -- 0模型阶段/1OSS阶段
tmp_file TEXT DEFAULT '', -- 临时结果文件路径(phase=1 时仅重试 OSS 上传)
input_ref TEXT DEFAULT '', -- 输入引用(如OSS/业务资源ID等)
request_payload JSONB, -- 请求参数(可选)
text_result TEXT DEFAULT '', -- 文本类结果(可选,支持直接回调)
epicycle_id VARCHAR(64) DEFAULT '', -- 轮次ID
expend_tokens BIGINT NOT NULL DEFAULT 0 -- 消耗 token 数
);
CREATE UNIQUE INDEX IF NOT EXISTS uk_asynch_task_tenant_task_id ON asynch_task(tenant_id, task_id);
CREATE INDEX IF NOT EXISTS idx_asynch_task_tenant_id ON asynch_task(tenant_id);
CREATE INDEX IF NOT EXISTS idx_asynch_task_model_name ON asynch_task(model_name);
CREATE INDEX IF NOT EXISTS idx_asynch_task_biz_name ON asynch_task(biz_name);
CREATE INDEX IF NOT EXISTS idx_asynch_task_model_key ON asynch_task(model_key);
CREATE INDEX IF NOT EXISTS idx_asynch_task_state ON asynch_task(state);
CREATE INDEX IF NOT EXISTS idx_asynch_task_enqueue_at ON asynch_task(enqueue_at);
CREATE INDEX IF NOT EXISTS idx_asynch_task_updated_at ON asynch_task(updated_at);
CREATE INDEX IF NOT EXISTS idx_asynch_task_expire_at ON asynch_task(expire_at);
CREATE INDEX IF NOT EXISTS idx_asynch_task_deleted_at ON asynch_task(deleted_at);
CREATE INDEX IF NOT EXISTS idx_asynch_task_epicycle_id ON asynch_task(epicycle_id);
CREATE INDEX IF NOT EXISTS idx_asynch_task_expend_tokens ON asynch_task(expend_tokens);
COMMENT ON TABLE asynch_task IS '异步任务表';
COMMENT ON COLUMN asynch_task.id IS '主键ID(非自增)';
COMMENT ON COLUMN asynch_task.tenant_id IS '租户ID';
COMMENT ON COLUMN asynch_task.creator IS '创建人';
COMMENT ON COLUMN asynch_task.created_at IS '创建时间';
COMMENT ON COLUMN asynch_task.updater IS '更新人';
COMMENT ON COLUMN asynch_task.updated_at IS '更新时间';
COMMENT ON COLUMN asynch_task.deleted_at IS '删除时间(软删)';
COMMENT ON COLUMN asynch_task.model_name IS '模型名称';
COMMENT ON COLUMN asynch_task.task_id IS '任务ID(对外返回)';
COMMENT ON COLUMN asynch_task.biz_name IS '业务名称(调用方模块/系统)';
COMMENT ON COLUMN asynch_task.callback_url IS '回调地址(可选,用于后续业务通知)';
COMMENT ON COLUMN asynch_task.model_key IS '动态请求头(用于覆盖/补充模型配置 head_msg),如 X-API-Key:xxx';
COMMENT ON COLUMN asynch_task.state IS '0排队中/1执行中/2成功/3失败/4已下载';
COMMENT ON COLUMN asynch_task.oss_file IS '结果文件OSS地址';
COMMENT ON COLUMN asynch_task.file_type IS '文件类型(mp3/mp4/png/...)';
COMMENT ON COLUMN asynch_task.file_size IS '文件大小(字节)';
COMMENT ON COLUMN asynch_task.error_msg IS '错误信息';
COMMENT ON COLUMN asynch_task.started_at IS '开始执行时间';
COMMENT ON COLUMN asynch_task.finished_at IS '执行结束时间';
COMMENT ON COLUMN asynch_task.duration_seconds IS '耗时(秒):从创建到完成(成功/失败)整体耗时';
COMMENT ON COLUMN asynch_task.expire_at IS 'state=4 后写入,用于清理';
COMMENT ON COLUMN asynch_task.retry_count IS '已重试次数(不含首次)';
COMMENT ON COLUMN asynch_task.enqueue_at IS '入队时间(用于排队顺序)';
COMMENT ON COLUMN asynch_task.phase IS '执行阶段 模型阶段/1OSS阶段(模型已成功,等待上传OSS)';
COMMENT ON COLUMN asynch_task.tmp_file IS '临时结果文件路径(phase=1 时仅重试 OSS 上传)';
COMMENT ON COLUMN asynch_task.input_ref IS '输入引用(如OSS/业务资源ID等)';
COMMENT ON COLUMN asynch_task.request_payload IS '请求参数(可选,JSON)';
COMMENT ON COLUMN asynch_task.text_result IS '文本类结果(可选,支持直接回调)';
COMMENT ON COLUMN asynch_task.epicycle_id IS '轮次ID(用于标识同一轮次的任务)';
COMMENT ON COLUMN asynch_task.expend_tokens IS '消耗 token 数';
-- ========================= -- =========================
-- 3) logs_model_op -- model_gateway_task
-- ========================= -- =========================
CREATE TABLE IF NOT EXISTS logs_model_op ( CREATE TABLE IF NOT EXISTS model_gateway_task (
-- 基础字段 id int8 PRIMARY KEY,
id BIGINT PRIMARY KEY, tenant_id int8 NOT NULL DEFAULT 0,
tenant_id BIGINT NOT NULL DEFAULT 0, creator varchar(64) NOT NULL,
creator VARCHAR(64) NOT NULL, created_at timestamp(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updater varchar(64) NOT NULL,
updater VARCHAR(64) NOT NULL, updated_at timestamp(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, deleted_at timestamp(6),
deleted_at TIMESTAMP(6), model_name varchar(128) NOT NULL,
-- 基础审计信息 task_id varchar(64) NOT NULL,
ip VARCHAR(64) DEFAULT '', biz_name varchar(128) NOT NULL DEFAULT '',
user_agent VARCHAR(256) DEFAULT '', callback_url varchar(512) DEFAULT '',
api_path VARCHAR(256) DEFAULT '', state int2 NOT NULL DEFAULT 0,
http_method VARCHAR(16) DEFAULT '', retry_count int4 NOT NULL DEFAULT 0,
-- 业务信息 phase int2 NOT NULL DEFAULT 0,
biz_name VARCHAR(128) NOT NULL DEFAULT '', -- 调用方业务模块/系统 tmp_file text DEFAULT '',
model_name VARCHAR(128) NOT NULL DEFAULT '', error_msg text DEFAULT '',
task_id VARCHAR(64) NOT NULL DEFAULT '', result_file jsonb NOT NULL DEFAULT '{}',
-- 统计字段 request_payload jsonb NOT NULL DEFAULT '{}',
op_type VARCHAR(64) NOT NULL DEFAULT 'createTask', -- 操作类型(默认创建任务) text_result jsonb NOT NULL DEFAULT '{}',
success SMALLINT NOT NULL DEFAULT 1, -- 1成功/0失败 expend_tokens int8 NOT NULL DEFAULT 0,
error_msg TEXT DEFAULT '', duration_seconds int8 NOT NULL DEFAULT 0,
cost_ms BIGINT NOT NULL DEFAULT 0, -- 耗时(毫秒) epicycle_id varchar(64) NOT NULL DEFAULT ''
-- 请求/响应 JSON(用于后期统计分析) );
request_payload JSONB,
response_payload JSONB
);
CREATE INDEX IF NOT EXISTS idx_logs_model_op_tenant_time ON logs_model_op(tenant_id, created_at); CREATE UNIQUE INDEX IF NOT EXISTS uk_model_gateway_task_tenant_creator_task_id ON model_gateway_task (tenant_id, creator, task_id);
CREATE INDEX IF NOT EXISTS idx_logs_model_op_model_name ON logs_model_op(model_name); CREATE INDEX IF NOT EXISTS idx_model_gateway_task_task_id ON model_gateway_task (task_id);
CREATE INDEX IF NOT EXISTS idx_logs_model_op_biz_name ON logs_model_op(biz_name); CREATE INDEX IF NOT EXISTS idx_model_gateway_task_state ON model_gateway_task (state);
CREATE INDEX IF NOT EXISTS idx_logs_model_op_task_id ON logs_model_op(task_id); CREATE INDEX IF NOT EXISTS idx_model_gateway_task_deleted_at ON model_gateway_task (deleted_at);
CREATE INDEX IF NOT EXISTS idx_logs_model_op_op_type ON logs_model_op(op_type);
CREATE INDEX IF NOT EXISTS idx_logs_model_op_deleted_at ON logs_model_op(deleted_at); COMMENT ON TABLE model_gateway_task IS '模型网关任务表';
COMMENT ON COLUMN model_gateway_task.id IS '主键ID';
COMMENT ON COLUMN model_gateway_task.tenant_id IS '租户ID';
COMMENT ON COLUMN model_gateway_task.creator IS '创建人';
COMMENT ON COLUMN model_gateway_task.created_at IS '创建时间';
COMMENT ON COLUMN model_gateway_task.updater IS '更新人';
COMMENT ON COLUMN model_gateway_task.updated_at IS '更新时间';
COMMENT ON COLUMN model_gateway_task.deleted_at IS '删除时间(软删)';
COMMENT ON COLUMN model_gateway_task.model_name IS '模型名称';
COMMENT ON COLUMN model_gateway_task.task_id IS '任务ID对外返回';
COMMENT ON COLUMN model_gateway_task.biz_name IS '业务名称(调用方模块/系统)';
COMMENT ON COLUMN model_gateway_task.callback_url IS '回调地址';
COMMENT ON COLUMN model_gateway_task.state IS '0排队中/1执行中/2成功/3失败/4已下载';
COMMENT ON COLUMN model_gateway_task.retry_count IS '已重试次数';
COMMENT ON COLUMN model_gateway_task.phase IS '执行阶段0模型阶段/1OSS阶段';
COMMENT ON COLUMN model_gateway_task.tmp_file IS '临时结果文件路径';
COMMENT ON COLUMN model_gateway_task.error_msg IS '错误信息';
COMMENT ON COLUMN model_gateway_task.result_file IS '结果文件:{oss_file, file_type, file_size}';
COMMENT ON COLUMN model_gateway_task.request_payload IS '请求参数JSON';
COMMENT ON COLUMN model_gateway_task.text_result IS '文本类结果';
COMMENT ON COLUMN model_gateway_task.expend_tokens IS '消耗token数';
COMMENT ON COLUMN model_gateway_task.duration_seconds IS '耗时(秒)';
COMMENT ON COLUMN model_gateway_task.epicycle_id IS '轮次ID';
COMMENT ON TABLE logs_model_op IS '操作记录日志表(创建任务等,用于统计)';
COMMENT ON COLUMN logs_model_op.id IS '主键ID(非自增)';
COMMENT ON COLUMN logs_model_op.tenant_id IS '租户ID';
COMMENT ON COLUMN logs_model_op.creator IS '创建人';
COMMENT ON COLUMN logs_model_op.created_at IS '创建时间';
COMMENT ON COLUMN logs_model_op.updater IS '更新人';
COMMENT ON COLUMN logs_model_op.updated_at IS '更新时间';
COMMENT ON COLUMN logs_model_op.deleted_at IS '删除时间(软删)';
COMMENT ON COLUMN logs_model_op.ip IS '客户端IP';
COMMENT ON COLUMN logs_model_op.user_agent IS 'User-Agent';
COMMENT ON COLUMN logs_model_op.api_path IS '接口路径';
COMMENT ON COLUMN logs_model_op.http_method IS 'HTTP方法';
COMMENT ON COLUMN logs_model_op.biz_name IS '业务名称(调用方模块/系统)';
COMMENT ON COLUMN logs_model_op.model_name IS '模型名称';
COMMENT ON COLUMN logs_model_op.task_id IS '任务ID';
COMMENT ON COLUMN logs_model_op.op_type IS '操作类型(如 createTask/getTaskResult/getTaskBatch 等)';
COMMENT ON COLUMN logs_model_op.success IS '是否成功1成功/0失败';
COMMENT ON COLUMN logs_model_op.error_msg IS '错误信息(失败时)';
COMMENT ON COLUMN logs_model_op.cost_ms IS '耗时(毫秒)';
COMMENT ON COLUMN logs_model_op.request_payload IS '请求 JSON';
COMMENT ON COLUMN logs_model_op.response_payload IS '响应 JSON';
-- ========================= -- =========================
-- 4) logs_model_stat -- model_gateway_log_stat
-- ========================= -- =========================
CREATE TABLE IF NOT EXISTS logs_model_stat ( CREATE TABLE IF NOT EXISTS model_gateway_log_stat (
day DATE NOT NULL, -- 天(YYYY-MM-DD) day date NOT NULL,
tenant_id BIGINT NOT NULL DEFAULT 0, -- 租户ID tenant_id int8 NOT NULL DEFAULT 0,
creator VARCHAR(64) NOT NULL DEFAULT '', -- 创建人 creator varchar(64) NOT NULL DEFAULT '',
model_name VARCHAR(128) NOT NULL DEFAULT '', -- 模型名称 model_name varchar(128) NOT NULL DEFAULT '',
request_count BIGINT NOT NULL DEFAULT 0, -- 请求次数 request_count int8 NOT NULL DEFAULT 0,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, created_at timestamp(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at timestamp(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY(day, tenant_id, creator, model_name) PRIMARY KEY (day, tenant_id, creator, model_name)
); );
-- 便于时间段/租户/人/模型过滤 CREATE INDEX IF NOT EXISTS idx_model_gateway_log_stat_day ON model_gateway_log_stat (day);
CREATE INDEX IF NOT EXISTS idx_logs_model_stat_tenant_day ON logs_model_stat(tenant_id, day); CREATE INDEX IF NOT EXISTS idx_model_gateway_log_stat_creator ON model_gateway_log_stat (creator);
CREATE INDEX IF NOT EXISTS idx_logs_model_stat_day ON logs_model_stat(day); CREATE INDEX IF NOT EXISTS idx_model_gateway_log_stat_model_name ON model_gateway_log_stat (model_name);
CREATE INDEX IF NOT EXISTS idx_logs_model_stat_model_name ON logs_model_stat(model_name); CREATE INDEX IF NOT EXISTS idx_model_gateway_log_stat_tenant_day ON model_gateway_log_stat (tenant_id, day);
CREATE INDEX IF NOT EXISTS idx_logs_model_stat_creator ON logs_model_stat(creator);
COMMENT ON TABLE logs_model_stat IS '按天模型请求统计(用于限流/监控)'; COMMENT ON TABLE model_gateway_log_stat IS '按天统计表';
COMMENT ON COLUMN logs_model_stat.day IS '(YYYY-MM-DD)'; COMMENT ON COLUMN model_gateway_log_stat.day IS 'YYYY-MM-DD';
COMMENT ON COLUMN logs_model_stat.tenant_id IS '租户ID'; COMMENT ON COLUMN model_gateway_log_stat.tenant_id IS '租户ID';
COMMENT ON COLUMN logs_model_stat.creator IS '创建人'; COMMENT ON COLUMN model_gateway_log_stat.creator IS '创建人';
COMMENT ON COLUMN logs_model_stat.model_name IS '模型名称'; COMMENT ON COLUMN model_gateway_log_stat.model_name IS '模型名称';
COMMENT ON COLUMN logs_model_stat.request_count IS '请求次数'; COMMENT ON COLUMN model_gateway_log_stat.request_count IS '请求次数';
COMMENT ON COLUMN logs_model_stat.created_at IS '创建时间'; COMMENT ON COLUMN model_gateway_log_stat.created_at IS '创建时间';
COMMENT ON COLUMN logs_model_stat.updated_at IS '更新时间'; COMMENT ON COLUMN model_gateway_log_stat.updated_at IS '更新时间';
-- =========================
-- model_gateway_logs_op
-- =========================
CREATE TABLE IF NOT EXISTS model_gateway_logs_op (
id int8 PRIMARY KEY,
tenant_id int8 NOT NULL DEFAULT 0,
creator varchar(64) NOT NULL,
created_at timestamp(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
updater varchar(64) NOT NULL,
updated_at timestamp(6) NOT NULL DEFAULT CURRENT_TIMESTAMP,
deleted_at timestamp(6),
ip varchar(64) DEFAULT '',
user_agent varchar(256) DEFAULT '',
api_path varchar(256) DEFAULT '',
http_method varchar(16) DEFAULT '',
biz_name varchar(128) NOT NULL DEFAULT '',
model_name varchar(128) NOT NULL DEFAULT '',
task_id varchar(64) NOT NULL DEFAULT '',
op_type varchar(64) NOT NULL DEFAULT 'createTask',
success int2 NOT NULL DEFAULT 1,
error_msg text DEFAULT '',
cost_ms int8 NOT NULL DEFAULT 0,
request_payload jsonb,
response_payload jsonb
);
CREATE INDEX IF NOT EXISTS idx_model_gateway_logs_op_task_id ON model_gateway_logs_op (task_id);
CREATE INDEX IF NOT EXISTS idx_model_gateway_logs_op_biz_name ON model_gateway_logs_op (biz_name);
CREATE INDEX IF NOT EXISTS idx_model_gateway_logs_op_model_name ON model_gateway_logs_op (model_name);
CREATE INDEX IF NOT EXISTS idx_model_gateway_logs_op_op_type ON model_gateway_logs_op (op_type);
CREATE INDEX IF NOT EXISTS idx_model_gateway_logs_op_deleted_at ON model_gateway_logs_op (deleted_at);
CREATE INDEX IF NOT EXISTS idx_model_gateway_logs_op_tenant_time ON model_gateway_logs_op (tenant_id, created_at);
COMMENT ON TABLE model_gateway_logs_op IS '操作日志表';
COMMENT ON COLUMN model_gateway_logs_op.id IS '主键ID非自增';
COMMENT ON COLUMN model_gateway_logs_op.tenant_id IS '租户ID';
COMMENT ON COLUMN model_gateway_logs_op.creator IS '创建人';
COMMENT ON COLUMN model_gateway_logs_op.created_at IS '创建时间';
COMMENT ON COLUMN model_gateway_logs_op.updater IS '更新人';
COMMENT ON COLUMN model_gateway_logs_op.updated_at IS '更新时间';
COMMENT ON COLUMN model_gateway_logs_op.deleted_at IS '删除时间(软删)';
COMMENT ON COLUMN model_gateway_logs_op.ip IS '客户端IP';
COMMENT ON COLUMN model_gateway_logs_op.user_agent IS 'User-Agent';
COMMENT ON COLUMN model_gateway_logs_op.api_path IS '接口路径';
COMMENT ON COLUMN model_gateway_logs_op.http_method IS 'HTTP方法';
COMMENT ON COLUMN model_gateway_logs_op.biz_name IS '业务名称(调用方模块/系统)';
COMMENT ON COLUMN model_gateway_logs_op.model_name IS '模型名称';
COMMENT ON COLUMN model_gateway_logs_op.task_id IS '任务ID';
COMMENT ON COLUMN model_gateway_logs_op.op_type IS '操作类型';
COMMENT ON COLUMN model_gateway_logs_op.success IS '是否成功1成功/0失败';
COMMENT ON COLUMN model_gateway_logs_op.error_msg IS '错误信息(失败时)';
COMMENT ON COLUMN model_gateway_logs_op.cost_ms IS '耗时(毫秒)';
COMMENT ON COLUMN model_gateway_logs_op.request_payload IS '请求 JSON';
COMMENT ON COLUMN model_gateway_logs_op.response_payload IS '响应 JSON';