diff --git a/common/util/files.go b/common/util/files.go new file mode 100644 index 0000000..34268a1 --- /dev/null +++ b/common/util/files.go @@ -0,0 +1,69 @@ +package util + +import ( + "fmt" + "net/http" + "os" + "path/filepath" + "strings" +) + +// DetectFileType 根据返回的二进制内容推断 contentType + 扩展名(尽量稳定) +func DetectFileType(data []byte) (contentType string, ext string) { + if len(data) == 0 { + return "application/octet-stream", "" + } + ct := http.DetectContentType(data) + // gateway.DetectContentType 可能带 charset 等参数:text/plain; charset=utf-8 + if idx := strings.Index(ct, ";"); idx > 0 { + ct = strings.TrimSpace(ct[:idx]) + } + switch ct { + case "audio/mpeg": + return ct, ".mp3" + case "audio/wave", "audio/wav", "audio/x-wav": + return ct, ".wav" + case "video/mp4": + return ct, ".mp4" + case "image/png": + return ct, ".png" + case "image/jpeg": + return ct, ".jpg" + case "application/pdf": + return ct, ".pdf" + case "text/plain": + return ct, ".txt" + case "application/json": + return ct, ".json" + default: + // 兜底:尝试从 ct 截取 subtype 作为后缀(例如 application/json) + if parts := strings.Split(ct, "/"); len(parts) == 2 { + sub := parts[1] + // 避免出现 "plain; charset=utf-8" 之类的后缀 + if idx := strings.Index(sub, ";"); idx > 0 { + sub = strings.TrimSpace(sub[:idx]) + } + return ct, "." + sub + } + return ct, "" + } +} + +// SaveTmpResult 将模型输出写入临时文件,用于 OSS 上传失败后的“仅重试 OSS”。 +func SaveTmpResult(taskID string, data []byte, ext string) (string, error) { + dir := filepath.Join(os.TempDir(), "model-asynch") + if err := os.MkdirAll(dir, 0o755); err != nil { + return "", err + } + if ext == "" { + ext = ".bin" + } + if ext[0] != '.' { + ext = "." + ext + } + path := filepath.Join(dir, fmt.Sprintf("%s%s", taskID, ext)) + if err := os.WriteFile(path, data, 0o644); err != nil { + return "", err + } + return path, nil +} diff --git a/common/util/headers.go b/common/util/headers.go new file mode 100644 index 0000000..c723646 --- /dev/null +++ b/common/util/headers.go @@ -0,0 +1,100 @@ +package util + +import ( + "context" + + "gitea.com/red-future/common/utils" + "github.com/gogf/gf/v2/frame/g" + "github.com/gogf/gf/v2/util/gconv" +) + +// AsyncCtx 固化异步上下文中的 token 和用户信息,避免请求结束后丢失 +func AsyncCtx(ctx context.Context) context.Context { + asyncCtx := context.WithoutCancel(ctx) + + if r := g.RequestFromCtx(ctx); r != nil { + if token := r.Header.Get("Authorization"); token != "" { + asyncCtx = context.WithValue(asyncCtx, "token", token) + } + if userInfo := r.Header.Get("X-User-Info"); userInfo != "" { + asyncCtx = context.WithValue(asyncCtx, "xUserInfo", userInfo) + } + } + + if user, err := utils.GetUserInfo(ctx); err == nil && user != nil { + asyncCtx = context.WithValue(asyncCtx, "user", user) + } + + return asyncCtx +} + +// ForwardHeaders 透传调用链路的头信息,优先使用 ctx 中的固化值 +func ForwardHeaders(ctx context.Context) map[string]string { + headers := make(map[string]string) + SetHeaderFromContext(headers, ctx, "Authorization", "token") + SetHeaderFromContext(headers, ctx, "X-User-Info", "xUserInfo") + FallbackToRequestHeaders(headers, ctx) + return headers +} + +// SetHeaderFromContext 从上下文中设置 header +func SetHeaderFromContext(headers map[string]string, ctx context.Context, headerKey, ctxKey string) { + if value, ok := ctx.Value(ctxKey).(string); ok && value != "" { + headers[headerKey] = value + } +} + +// FallbackToRequestHeaders 从请求头中获取作为兜底 +func FallbackToRequestHeaders(headers map[string]string, ctx context.Context) { + r := g.RequestFromCtx(ctx) + if r == nil { + return + } + + if headers["Authorization"] == "" { + if token := r.Header.Get("Authorization"); token != "" { + headers["Authorization"] = token + } + } + + if headers["X-User-Info"] == "" { + if userInfo := r.Header.Get("X-User-Info"); userInfo != "" { + headers["X-User-Info"] = userInfo + } + } +} + +// SetTaskHeadersToCtx 把任务入库时保存的 header 信息注入 ctx,给 worker 调 OSS 用 +func SetTaskHeadersToCtx(ctx context.Context, headers map[string]string) context.Context { + if headers == nil { + return ctx + } + if v := gconv.String(headers["Authorization"]); v != "" { + ctx = context.WithValue(ctx, "token", v) + } + if v := gconv.String(headers["X-User-Info"]); v != "" { + ctx = context.WithValue(ctx, "xUserInfo", v) + } + return ctx +} + +// ParseStoredPayload 解析入库的 request_payload,拆出模型调用 payload 与透传 headers +// 入库格式:{"payload": , "headers": {"Authorization": "...", "X-User-Info":"..."}} +func ParseStoredPayload(v any) (payload any, headers map[string]string) { + if v == nil { + return nil, nil + } + m := gconv.Map(v) + if len(m) == 0 { + return v, nil + } + if h, ok := m["headers"]; ok { + headers = gconv.MapStrStr(h) + } + if p, ok := m["payload"]; ok { + payload = p + } else { + payload = v + } + return +} diff --git a/common/util/json.go b/common/util/json.go new file mode 100644 index 0000000..2da838b --- /dev/null +++ b/common/util/json.go @@ -0,0 +1,28 @@ +package util + +import ( + "encoding/json" + + "github.com/gogf/gf/v2/container/gvar" +) + +func ParseJSONField(field any) any { + var v *gvar.Var + switch val := field.(type) { + case *gvar.Var: + v = val + default: + return field + } + + if v == nil || v.IsNil() || v.IsEmpty() { + return nil + } + + str := v.String() + var result any + if json.Unmarshal([]byte(str), &result) == nil { + return result + } + return str +} diff --git a/config.yml b/config.yml index 8491f74..5e5be98 100644 --- a/config.yml +++ b/config.yml @@ -26,6 +26,27 @@ database: updatedAt: "updated_at" # (可选)自动更新时间字段名称 deletedAt: "deleted_at" # (可选)软删除时间字段名称 timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性,为true时CreatedAt/UpdatedAt/DeletedAt都将失效 + model_gateway: + - type: "pgsql" + host: "116.204.74.41" + port: "15432" + user: "postgres" + pass: "Bjang09@686^*^" + name: "model-gateway" + prefix: "" + role: "master" + debug: true + dryRun: false + charset: "utf8" + timezone: "Asia/Shanghai" + maxIdle: 5 + maxOpen: 20 + maxLifetime: "30s" + maxIdleConnTime: "30s" + createdAt: "created_at" + updatedAt: "updated_at" + deletedAt: "deleted_at" + timeMaintainDisabled: false redis: default: @@ -48,11 +69,3 @@ asynch: cleaner: enabled: false intervalSeconds: 30 - -modelType: - types: - 1: "推理模型" - 2: "图片模型" - 3: "音频模型" - 4: "向量化模型" - 5: "全模态模型" diff --git a/consts/public/public.go b/consts/public/public.go new file mode 100644 index 0000000..fcddd35 --- /dev/null +++ b/consts/public/public.go @@ -0,0 +1,19 @@ +package public + +// ModelType 模型类型常量 +const ( + ModelTypeInference = 1 // 推理模型 + ModelTypeImage = 2 // 图片模型 + ModelTypeAudio = 3 // 音频模型 + ModelTypeVector = 4 // 向量化模型 + ModelTypeOmni = 5 // 全模态模型 +) + +// ModelTypeName 模型类型名称映射 +var ModelTypeName = map[int]string{ + ModelTypeInference: "推理模型", + ModelTypeImage: "图片模型", + ModelTypeAudio: "音频模型", + ModelTypeVector: "向量化模型", + ModelTypeOmni: "全模态模型", +} diff --git a/consts/public/table_name.go b/consts/public/table_name.go index 8ee0bdf..16cf1c0 100644 --- a/consts/public/table_name.go +++ b/consts/public/table_name.go @@ -1,5 +1,9 @@ package public +const ( + DbNameModelGateway = "model_gateway" //数据库名称 +) + const ( TableNameModel = "asynch_models" // 模型表 TableNameTask = "asynch_task" // 任务表 diff --git a/controller/model_controller.go b/controller/model_controller.go index 204e737..619b49e 100644 --- a/controller/model_controller.go +++ b/controller/model_controller.go @@ -4,10 +4,7 @@ import ( "context" "model-gateway/model/dto" - "model-gateway/model/entity" "model-gateway/service" - - "gitea.com/red-future/common/beans" ) type model struct{} @@ -21,67 +18,44 @@ func (c *model) CreateModel(ctx context.Context, req *dto.CreateModelReq) (res * } // UpdateModel 更改配置 -func (c *model) UpdateModel(ctx context.Context, req *dto.UpdateModelReq) (res *beans.ResponseEmpty, err error) { +func (c *model) UpdateModel(ctx context.Context, req *dto.UpdateModelReq) (res *dto.UpdateModelRes, err error) { err = service.Model.Update(ctx, req) return } // DeleteModel 删除配置 -func (c *model) DeleteModel(ctx context.Context, req *dto.DeleteModelReq) (res *beans.ResponseEmpty, err error) { - err = service.Model.Delete(ctx, req.ID) +func (c *model) DeleteModel(ctx context.Context, req *dto.DeleteModelReq) (res *dto.DeleteModelRes, err error) { + err = service.Model.Delete(ctx, req) return } -// GetModel 获取配置详情(按 modelName) +// GetModel 获取配置详情 func (c *model) GetModel(ctx context.Context, req *dto.GetModelReq) (res *dto.GetModelRes, err error) { - model, err := service.Model.Get(ctx, req.ID) - if err != nil { - return nil, err - } - if model == nil { - return nil, nil - } - return &dto.GetModelRes{Model: model}, nil + return service.Model.Get(ctx, req) } // ListModel 配置列表 func (c *model) ListModel(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) { - list, total, err := service.Model.List(ctx, req) - if err != nil { - return nil, err - } - return &dto.ListModelRes{ - List: list, - Total: total, - }, nil + return service.Model.List(ctx, req) } // AutoTune 动态调参(由上层定时任务每小时触发一次) func (c *model) AutoTune(ctx context.Context, req *dto.AutoTuneReq) (res *dto.AutoTuneRes, err error) { - windowSeconds := 3600 - if req != nil && req.WindowSeconds > 0 { - windowSeconds = req.WindowSeconds - } - list, err := service.AutoTune(ctx, windowSeconds) - if err != nil { - return nil, err - } - return &dto.AutoTuneRes{List: list}, nil + return service.AutoTune(ctx, req) } -func (c *model) ListType(ctx context.Context, req *dto.ListTypeReq) (res dto.TypeItem, err error) { - modelType := service.GetModelTypesFromConfig(ctx) - res.Type = modelType - return res, nil +// ListType 模型类型列表 +func (c *model) ListType(ctx context.Context, req *dto.ListTypeReq) (res *dto.TypeItem, err error) { + return service.GetModelTypesFromConfig() } // UpdateChatModel 更新是否为聊天模型 -func (c *model) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) (res *beans.ResponseEmpty, err error) { +func (c *model) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) (res *dto.UpdateChatModelRes, err error) { err = service.Model.UpdateChatModel(ctx, req) return } -// GetIsChatModel 获取是否为聊天模型 -func (c *model) GetIsChatModel(ctx context.Context, req *dto.GetIsChatModelReq) (res *entity.AsynchModel, err error) { +// GetIsChatModel 获取当前会话模型 +func (c *model) GetIsChatModel(ctx context.Context, req *dto.GetIsChatModelReq) (res *dto.GetIsChatModelRes, err error) { return service.Model.GetIsChatModel(ctx) } diff --git a/controller/task_controller.go b/controller/task_controller.go index f305ae8..f62a7bb 100644 --- a/controller/task_controller.go +++ b/controller/task_controller.go @@ -34,24 +34,10 @@ func (c *task) ListTask(ctx context.Context, req *dto.ListTaskReq) (res *dto.Lis // RunWork 手动触发一次 worker(由上层定时任务调用) func (c *task) RunWork(ctx context.Context, req *dto.RunWorkReq) (res *dto.RunWorkRes, err error) { - batchSize, goroutines := 10, 1 - if req != nil { - if req.BatchSize > 0 { - batchSize = req.BatchSize - } - if req.Goroutines > 0 { - goroutines = req.Goroutines - } - } - n, err := service.AsyncWorker.RunOnce(ctx, batchSize, goroutines) - if err != nil { - return nil, err - } - return &dto.RunWorkRes{Claimed: n}, nil + return service.AsyncWorker.RunOnce(ctx, req) } // CleanWork 手动触发一次 cleaner(由上层定时任务调用) func (c *task) CleanWork(ctx context.Context, req *dto.CleanWorkReq) (res *dto.CleanWorkRes, err error) { - service.Cleaner.RunOnce(ctx) - return &dto.CleanWorkRes{Ok: true}, nil + return service.Cleaner.RunOnce(ctx) } diff --git a/dao/model_dao.go b/dao/model_dao.go index c0fbf2d..92e6d91 100644 --- a/dao/model_dao.go +++ b/dao/model_dao.go @@ -2,14 +2,11 @@ package dao import ( "context" - "fmt" - "model-gateway/consts/public" "model-gateway/model/dto" "model-gateway/model/entity" "gitea.com/red-future/common/db/gfdb" - "gitea.com/red-future/common/utils" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/util/gconv" ) @@ -18,157 +15,80 @@ var Model = &modelDao{} type modelDao struct{} -func (d *modelDao) Insert(ctx context.Context, req *dto.CreateModelReq) (id int64, err error) { - asyncModel := new(entity.AsynchModel) - err = gconv.Struct(req, &asyncModel) +// Insert 插入 +func (d *modelDao) Insert(ctx context.Context, req *entity.AsynchModel) (id int64, err error) { + m := new(entity.AsynchModel) + err = gconv.Struct(req, &m) if err != nil { return } - r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).Data(asyncModel).Insert() + r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel). + Insert(m) if err != nil { - return 0, err + return } return r.LastInsertId() } -func (d *modelDao) Update(ctx context.Context, m *dto.UpdateModelReq) (rows int64, err error) { - // 触发 gfdb 的 updateHook 自动填充 updater,需要显式带 updater 字段 - r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel). +// Update 更新 +func (d *modelDao) Update(ctx context.Context, req *entity.AsynchModel) (rows int64, err error) { + r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel). OmitEmpty(). - Where(entity.AsynchModelCol.Id, m.ID). - Data(m). + Data(&req). + Where(entity.AsynchModelCol.Id, req.Id). Update() if err != nil { - return 0, err + return } return r.RowsAffected() } -func (d *modelDao) DeleteByID(ctx context.Context, id string) (rows int64, err error) { - r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel). - Where(entity.AsynchModelCol.Id, id). +// Delete 删除 +func (d *modelDao) Delete(ctx context.Context, req *entity.AsynchModel) (rows int64, err error) { + r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel). + OmitEmpty(). + Where(entity.AsynchModelCol.Id, req.Id). Delete() if err != nil { - return 0, err + return } return r.RowsAffected() } -func (d *modelDao) GetByModelName(ctx context.Context, modelName string) (m *entity.AsynchModel, err error) { - r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel). - Where(entity.AsynchModelCol.ModelName, modelName). - One() - if err != nil { - return nil, err - } - if r.IsEmpty() { - return nil, nil - } - err = r.Struct(&m) - return -} - -func (d *modelDao) Get(ctx context.Context, id int64) (m *entity.AsynchModel, err error) { - r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel). - NoTenantId(ctx). - Where(entity.AsynchModelCol.Id, id). - One() - if err != nil { - return nil, err - } - if r.IsEmpty() { - return nil, nil - } - err = r.Struct(&m) - return -} - -func (d *modelDao) Count(ctx context.Context, req *dto.GetModelReq) (count int, err error) { - count, err = gfdb.DB(ctx).Model(ctx, public.TableNameModel).OmitEmpty(). +// Get 按ID获取(带租户隔离,只查当前租户) +func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) { + r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel). + OmitEmpty(). + Where(entity.AsynchModelCol.Id, req.Id). Where(entity.AsynchModelCol.Creator, req.Creator). - Where(entity.AsynchModelCol.Id, req.ID).Count() - return -} - -func (d *modelDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike string, modelType int, isPrivate int) (list []*entity.AsynchModel, total int64, err error) { - model := gfdb.DB(ctx).Model(ctx, public.TableNameModel). - OrderDesc(entity.AsynchModelCol.CreatedAt) - if modelNameLike != "" { - model = model.WhereLike(entity.AsynchModelCol.ModelName, "%"+modelNameLike+"%") - } - if modelType != 0 { - model = model.Where(entity.AsynchModelCol.ModelType, modelType) - } - if isPrivate != 0 { - model = model.Where(entity.AsynchModelCol.IsPrivate, isPrivate) - } - if pageNum > 0 && pageSize > 0 { - model = model.Page(pageNum, pageSize) - } - r, totalInt, err := model.AllAndCount(false) + Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel). + Where(entity.AsynchModelCol.ModelName, req.ModelName). + Fields(fields).One() if err != nil { - return nil, 0, err - } - total = gconv.Int64(totalInt) - err = r.Structs(&list) - return -} - -func (d *modelDao) GetByIsChatModel(ctx context.Context) (m *entity.AsynchModel, err error) { - userInfo, err := utils.GetUserInfo(ctx) - if err != nil { - return nil, err - } - r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel). - Where(entity.AsynchModelCol.IsChatModel, 1). - Where(entity.AsynchModelCol.Creator, userInfo.UserName). - One() - if err != nil { - return nil, err - } - if r.IsEmpty() { - return nil, nil + return } err = r.Struct(&m) return } -// ListByCreatorAndPlatform 普通用户:平台公共(tenant_id=0) + 自己创建的(creator=xxx) -func (d *modelDao) ListByCreatorAndPlatform(ctx context.Context, creator string, pageNum, pageSize int, modelNameLike string) (list []*entity.AsynchModel, total int64, err error) { - // 构建 Where 条件 - whereSQL := "deleted_at IS NULL AND (tenant_id = 1 OR creator = ?)" //1 代表超级管理员 - args := []any{creator} - - if modelNameLike != "" { - whereSQL += " AND model_name LIKE ?" - args = append(args, "%"+modelNameLike+"%") - } - - // 查总数 - countSQL := fmt.Sprintf("SELECT COUNT(1) FROM %s WHERE %s", public.TableNameModel, whereSQL) - countResult, err := gfdb.DB(ctx).GetAll(ctx, countSQL, args...) +// GetByAcrossTenant 按ID获取(跨租户,查所有租户) +func (d *modelDao) GetByAcrossTenant(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) { + r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel). + NoTenantId(ctx). + OmitEmpty(). + Where(entity.AsynchModelCol.Id, req.Id). + Where(entity.AsynchModelCol.Creator, req.Creator). + Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel). + Where(entity.AsynchModelCol.ModelName, req.ModelName). + Fields(fields).One() if err != nil { - return nil, 0, err + return } - if len(countResult) > 0 { - total = gconv.Int64(countResult[0]["count"]) - } - - // 查列表 - querySQL := fmt.Sprintf("SELECT * FROM %s WHERE %s ORDER BY created_at DESC", public.TableNameModel, whereSQL) - if pageNum > 0 && pageSize > 0 { - offset := (pageNum - 1) * pageSize - querySQL += fmt.Sprintf(" LIMIT %d OFFSET %d", pageSize, offset) - } - - r, err := gfdb.DB(ctx).GetAll(ctx, querySQL, args...) - if err != nil { - return nil, 0, err - } - - err = r.Structs(&list) + err = r.Struct(&m) return } + +// GetByCreatorAndPlatform 按创建者、平台获取 func (d *modelDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.AsynchModel, total int, err error) { // 基础 SQL sql := ` @@ -212,7 +132,7 @@ WHERE deleted_at IS NULL // 最后拼接排序 sql += ` ORDER BY model_name, is_owner DESC, created_at DESC` - r, err := gfdb.DB(ctx).GetAll(ctx, sql, args...) + r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, sql, args...) if err != nil { return nil, 0, err } @@ -226,14 +146,24 @@ WHERE deleted_at IS NULL return } -// ListAll 用于分组展示:查询全部模型(不按类型过滤,类型拆分在 service 层处理) -func (d *modelDao) ListAll(ctx context.Context) (list []*entity.AsynchModel, err error) { - r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel). - OrderDesc(entity.AsynchModelCol.CreatedAt). - All() +// GetByModelNameForTenant 后台任务使用:按 tenant_id + model_name 查询,不依赖 gfdb Hook/Trace/用户上下文 +func (d *modelDao) GetByModelNameForTenant(ctx context.Context, tenantId uint64, modelName string) (m *entity.AsynchModel, err error) { + r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, + "SELECT * FROM "+public.TableNameModel+" WHERE tenant_id=? AND model_name=? AND deleted_at IS NULL LIMIT 1", + tenantId, modelName, + ) if err != nil { return nil, err } - err = r.Structs(&list) - return + if r.IsEmpty() { + return nil, nil + } + var list []*entity.AsynchModel + if err := r.Structs(&list); err != nil { + return nil, err + } + if len(list) == 0 { + return nil, nil + } + return list[0], nil } diff --git a/dao/model_dao_bg.go b/dao/model_dao_bg.go deleted file mode 100644 index beddd2f..0000000 --- a/dao/model_dao_bg.go +++ /dev/null @@ -1,32 +0,0 @@ -package dao - -import ( - "context" - - "model-gateway/consts/public" - "model-gateway/model/entity" - - "gitea.com/red-future/common/db/gfdb" -) - -// GetByModelNameForTenant 后台任务使用:按 tenant_id + model_name 查询,不依赖 gfdb Hook/Trace/用户上下文 -func (d *modelDao) GetByModelNameForTenant(ctx context.Context, tenantId uint64, modelName string) (m *entity.AsynchModel, err error) { - r, err := gfdb.DB(ctx).GetAll(ctx, - "SELECT * FROM "+public.TableNameModel+" WHERE tenant_id=? AND model_name=? AND deleted_at IS NULL LIMIT 1", - tenantId, modelName, - ) - if err != nil { - return nil, err - } - if r.IsEmpty() { - return nil, nil - } - var list []*entity.AsynchModel - if err := r.Structs(&list); err != nil { - return nil, err - } - if len(list) == 0 { - return nil, nil - } - return list[0], nil -} diff --git a/dao/op_log_dao.go b/dao/op_log_dao.go index 6888827..8880d77 100644 --- a/dao/op_log_dao.go +++ b/dao/op_log_dao.go @@ -7,14 +7,22 @@ import ( "model-gateway/model/entity" "gitea.com/red-future/common/db/gfdb" + "github.com/gogf/gf/v2/util/gconv" ) type opLogDao struct{} var OpLog = &opLogDao{} -func (d *opLogDao) Insert(ctx context.Context, log *entity.LogsModelOp) (id int64, err error) { - r, err := gfdb.DB(ctx).Model(ctx, public.TableNameOpLog).Data(log).Insert() +// Insert 插入 +func (d *opLogDao) Insert(ctx context.Context, req *entity.LogsModelOp) (id int64, err error) { + m := new(entity.LogsModelOp) + err = gconv.Struct(req, &m) + if err != nil { + return + } + r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameOpLog). + Insert(m) if err != nil { return 0, err } diff --git a/dao/stat_dao.go b/dao/stat_dao.go index e6cdfdc..c523cec 100644 --- a/dao/stat_dao.go +++ b/dao/stat_dao.go @@ -25,7 +25,7 @@ ON CONFLICT (day, tenant_id, creator, model_name) DO UPDATE SET request_count = %s.request_count + 1, updated_at = NOW()`, public.TableNameStat, public.TableNameStat, ) - _, err := gfdb.DB(ctx).Exec(ctx, sql, gtime.New(day).Format("Y-m-d"), tenantId, creator, modelName) + _, err := gfdb.DB(ctx, public.DbNameModelGateway).Exec(ctx, sql, gtime.New(day).Format("Y-m-d"), tenantId, creator, modelName) return err } diff --git a/dao/task_dao.go b/dao/task_dao.go index fcaaffd..b5b8f26 100644 --- a/dao/task_dao.go +++ b/dao/task_dao.go @@ -2,9 +2,6 @@ package dao import ( "context" - "fmt" - "time" - "model-gateway/consts/public" "model-gateway/model/entity" @@ -18,40 +15,47 @@ var Task = &taskDao{} type taskDao struct{} -func (d *taskDao) Insert(ctx context.Context, t *entity.AsynchTask) (id int64, err error) { - r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).Data(t).Insert() +// Insert 插入 +func (d *taskDao) Insert(ctx context.Context, req *entity.AsynchTask) (id int64, err error) { + m := new(entity.AsynchTask) + err = gconv.Struct(req, &m) if err != nil { - return 0, err + return + } + r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask). + Insert(m) + if err != nil { + return } return r.LastInsertId() } -func (d *taskDao) GetByTaskID(ctx context.Context, taskID string) (t *entity.AsynchTask, err error) { - r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask). - Where(entity.AsynchTaskCol.TaskID, taskID). - One() +// Get 获取 +func (d *taskDao) Get(ctx context.Context, req *entity.AsynchTask, fields ...string) (m *entity.AsynchTask, err error) { + r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask). + OmitEmpty(). + Where(entity.AsynchTaskCol.TaskID, req.TaskID). + Fields(fields).One() if err != nil { - return nil, err + return } - if r.IsEmpty() { - return nil, nil - } - err = r.Struct(&t) + err = r.Struct(&m) return } // ListByTaskIDs 批量查询任务(会受 gfdb 的租户 Hook 影响,只返回当前租户数据) -func (d *taskDao) ListByTaskIDs(ctx context.Context, taskIDs []string) (list []*entity.AsynchTask, err error) { +func (d *taskDao) ListByTaskIDs(ctx context.Context, taskIDs []string) (m []*entity.AsynchTask, err error) { if len(taskIDs) == 0 { return nil, nil } - r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask). + r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask). + OmitEmpty(). WhereIn(entity.AsynchTaskCol.TaskID, taskIDs). All() if err != nil { return nil, err } - err = r.Structs(&list) + err = r.Structs(&m) return } @@ -62,7 +66,7 @@ func (d *taskDao) MarkDownloadedByID(ctx context.Context, id int64, expireAt *gt entity.AsynchTaskCol.ExpireAt: expireAt, entity.AsynchTaskCol.Updater: "", } - _, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask). + _, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask). Where(entity.AsynchTaskCol.Id, id). Where(entity.AsynchTaskCol.State, 2). Data(data). @@ -70,73 +74,6 @@ func (d *taskDao) MarkDownloadedByID(ctx context.Context, id int64, expireAt *gt return err } -func (d *taskDao) UpdateRunning(ctx context.Context, id int64) error { - now := gtime.Now() - data := gdb.Map{ - entity.AsynchTaskCol.State: 1, - entity.AsynchTaskCol.StartedAt: now, - entity.AsynchTaskCol.Updater: "", - } - _, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask). - Where(entity.AsynchTaskCol.Id, id). - Data(data). - Update() - return err -} - -func (d *taskDao) UpdateSuccess(ctx context.Context, id int64, ossFile, fileType string, fileSize int64, expireAt *gtime.Time) error { - now := gtime.Now() - data := gdb.Map{ - entity.AsynchTaskCol.State: 2, - entity.AsynchTaskCol.OssFile: ossFile, - entity.AsynchTaskCol.FileType: fileType, - entity.AsynchTaskCol.FileSize: fileSize, - entity.AsynchTaskCol.ErrorMsg: "", - entity.AsynchTaskCol.FinishedAt: now, - entity.AsynchTaskCol.ExpireAt: expireAt, - entity.AsynchTaskCol.Updater: "", - } - _, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask). - Where(entity.AsynchTaskCol.Id, id). - Data(data). - Update() - return err -} - -func (d *taskDao) UpdateFailed(ctx context.Context, id int64, errorMsg string) error { - now := gtime.Now() - data := gdb.Map{ - entity.AsynchTaskCol.State: 3, - entity.AsynchTaskCol.ErrorMsg: errorMsg, - entity.AsynchTaskCol.FinishedAt: now, - entity.AsynchTaskCol.Updater: "", - } - _, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask). - Where(entity.AsynchTaskCol.Id, id). - Data(data). - Update() - return err -} - -func (d *taskDao) SoftDeleteByTaskID(ctx context.Context, taskID string) (rows int64, err error) { - r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask). - Where(entity.AsynchTaskCol.TaskID, taskID). - Delete() - if err != nil { - return 0, err - } - return r.RowsAffected() -} - -// CountActiveByModel 统计某模型排队中/执行中的任务数,用于 queue_limit 限制(近似值) -func (d *taskDao) CountActiveByModel(ctx context.Context, modelName string) (int64, error) { - n, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask). - Where(entity.AsynchTaskCol.ModelName, modelName). - WhereIn(entity.AsynchTaskCol.State, []int{0, 1}). - Count() - return int64(n), err -} - // List 任务分页查询(受 gfdb 租户 Hook 影响) func (d *taskDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike, taskIDLike string, state *int) (list []*entity.AsynchTask, total int64, err error) { m := gfdb.DB(ctx).Model(ctx, public.TableNameTask).Where("deleted_at IS NULL") @@ -161,90 +98,3 @@ func (d *taskDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike err = r.Structs(&list) return } - -// ClaimPending 抢占 pending 任务(state=0),并在同一事务中更新为 running(state=1) -// 使用 PostgreSQL: FOR UPDATE SKIP LOCKED 避免多 worker 重复消费 -func (d *taskDao) ClaimPending(ctx context.Context, batchSize int) (tasks []*entity.AsynchTask, err error) { - if batchSize <= 0 { - batchSize = 1 - } - err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error { - sql := fmt.Sprintf( - `SELECT id, tenant_id, model_name, task_id, input_ref, request_payload - FROM %s - WHERE deleted_at IS NULL AND state = 0 - ORDER BY created_at ASC - LIMIT %d - FOR UPDATE SKIP LOCKED`, - public.TableNameTask, - batchSize, - ) - r, err := tx.GetAll(sql) - if err != nil { - return err - } - if r.IsEmpty() { - tasks = nil - return nil - } - if err := r.Structs(&tasks); err != nil { - return err - } - // 更新为 running - now := time.Now() - for _, t := range tasks { - // tx.Model 不走 gfdb Hook,这里手动更新必要字段 - _, err = tx.Exec( - fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask), - now, now, t.Id, - ) - if err != nil { - return err - } - } - return nil - }) - return -} - -// ListExpiredSuccess 获取已成功且过期的任务 -func (d *taskDao) ListExpiredSuccess(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) { - if limit <= 0 { - limit = 100 - } - r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask). - Where(entity.AsynchTaskCol.State, 2). - Where(entity.AsynchTaskCol.ExpireAt+" IS NOT NULL"). - Where(entity.AsynchTaskCol.ExpireAt+" < ?", gtime.Now()). - Limit(limit). - All() - if err != nil { - return nil, err - } - err = r.Structs(&list) - return -} - -// ListTimeoutTasks 获取超时的排队/执行中任务 -func (d *taskDao) ListTimeoutTasks(ctx context.Context, timeout time.Duration, limit int) (list []*entity.AsynchTask, err error) { - if limit <= 0 { - limit = 100 - } - deadline := gtime.New(time.Now().Add(-timeout)) - r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask). - WhereIn(entity.AsynchTaskCol.State, []int{0, 1}). - Where(entity.AsynchTaskCol.UpdatedAt+" < ?", deadline). - Limit(limit). - All() - if err != nil { - return nil, err - } - err = r.Structs(&list) - return -} - -// DebugPing 用于启动时检测数据库连通性(可选) -func (d *taskDao) DebugPing(ctx context.Context) error { - _, err := gfdb.DB(ctx).GetAll(ctx, "SELECT 1") - return err -} diff --git a/dao/task_dao_bg.go b/dao/task_dao_bg.go index 500c69f..fc0a2fe 100644 --- a/dao/task_dao_bg.go +++ b/dao/task_dao_bg.go @@ -150,14 +150,6 @@ func (d *taskDao) UpdateTmpAfterModelGlobal(ctx context.Context, id int64, tmpFi return err } -func (d *taskDao) SoftDeleteByTaskIDGlobal(ctx context.Context, taskID string) error { - _, err := gfdb.DB(ctx).Exec(ctx, - fmt.Sprintf(`UPDATE %s SET deleted_at=NOW(), updated_at=NOW() WHERE task_id=? AND deleted_at IS NULL`, public.TableNameTask), - taskID, - ) - return err -} - func (d *taskDao) RollbackToPendingGlobal(ctx context.Context, id int64) error { _, err := gfdb.DB(ctx).Exec(ctx, fmt.Sprintf(`UPDATE %s SET state=0, enqueue_at=NOW(), updated_at=NOW() WHERE id=? AND state=1`, public.TableNameTask), diff --git a/main.go b/main.go index d956e49..61c0cdb 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "model-gateway/model/dto" "os" "os/signal" "syscall" @@ -61,7 +62,10 @@ func startAutoRunner(ctx context.Context) { case <-ctx.Done(): return case <-ticker.C: - if _, err := service.AsyncWorker.RunOnce(ctx, batchSize, goroutines); err != nil { + if _, err := service.AsyncWorker.RunOnce(ctx, &dto.RunWorkReq{ + BatchSize: batchSize, + Goroutines: goroutines, + }); err != nil { g.Log().Warningf(ctx, "[auto-worker] run once failed: %v", err) } } @@ -83,7 +87,7 @@ func startAutoRunner(ctx context.Context) { case <-ctx.Done(): return case <-ticker.C: - service.Cleaner.RunOnce(ctx) + _, _ = service.Cleaner.RunOnce(ctx) } } }() diff --git a/model/dto/model_dto.go b/model/dto/model_dto.go index 2b2ed66..92503fb 100644 --- a/model/dto/model_dto.go +++ b/model/dto/model_dto.go @@ -17,12 +17,14 @@ type CreateModelReq struct { Enabled *int `p:"enabled" json:"enabled" v:"in:0,1#启用参数只能为0或1" dc:"是否启用:0-禁用,1-启用(默认1)"` IsChatModel *int `p:"isChatModel" json:"isChatModel" v:"in:0,1#对话模型参数只能为0或1" dc:"是否为对话模型:0-否,1-是(默认0)"` IsOwner *int `p:"isOwner" json:"isOwner" v:"in:0,1#是否为所有者参数只能为0或1" dc:"是否为所有者:0-否,1-是(默认0)"` + OperatorName string `p:"operatorName" json:"operatorName" v:"required#operatorName不能为空" dc:"运营商名称"` + TokenConfig any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"` ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥,用于模型认证"` Form any `p:"form" json:"form" dc:"动态表单配置(JSON),用于前端渲染配置项"` RequestMapping any `p:"requestMapping" json:"requestMapping" dc:"请求映射"` ResponseMapping any `p:"responseMapping" json:"responseMapping" dc:"返回映射"` ResponseBody any `p:"responseBody" json:"responseBody" dc:"返回主体"` - TokenMapping string `p:"tokenMapping" json:"tokenMapping" dc:"token映射"` + ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"` MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数(默认10)"` QueueLimit int `p:"queueLimit" json:"queueLimit" dc:"排队队列上限(默认1000)"` TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒,默认600)"` @@ -50,11 +52,13 @@ type UpdateModelReq struct { RequestMapping any `p:"requestMapping" json:"requestMapping" dc:"请求参数映射(可选更新)"` ResponseMapping any `p:"responseMapping" json:"responseMapping" dc:"返回参数映射(可选更新)"` ResponseBody any `p:"responseBody" json:"responseBody" dc:"返回主体(可选更新)"` - TokenMapping string `p:"tokenMapping" json:"tokenMapping" dc:"token映射(可选更新)"` + ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"` Enabled *int `p:"enabled" json:"enabled" dc:"是否启用:0-禁用,1-启用(可选更新)"` IsPrivate *int `p:"isPrivate" json:"isPrivate" v:"in:0,1#私有化参数只能为0或1" dc:"是否私有化:0-私有(默认) 1-公共"` IsChatModel *int `p:"isChatModel" json:"isChatModel" v:"in:0,1#对话模型参数只能为0或1" dc:"是否为对话模型:0-否,1-是(默认0)"` IsOwner *int `p:"isOwner" json:"isOwner" v:"in:0,1#是否为所有者参数只能为0或1" dc:"是否为所有者:0-否,1-是(默认0)"` + OperatorName string `p:"operatorName" json:"operatorName" v:"required#operatorName不能为空" dc:"运营商名称"` + TokenConfig any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"` MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数(可选更新)"` QueueLimit int `p:"queueLimit" json:"queueLimit" dc:"排队队列上限(可选更新)"` TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒)(可选更新)"` @@ -65,10 +69,18 @@ type UpdateModelReq struct { Remark string `p:"remark" json:"remark" dc:"备注说明(可选更新)"` } +type UpdateModelRes struct { + ID int64 `json:"id,string" dc:"配置ID"` +} + // DeleteModelReq 删除模型配置 type DeleteModelReq struct { g.Meta `path:"/deleteModel" method:"delete" tags:"模型管理" summary:"删除模型配置" dc:"删除指定ID的模型配置"` - ID string `p:"id" json:"id,string" v:"required#id不能为空" dc:"配置ID"` + ID int64 `p:"id" json:"id,string" v:"required#id不能为空" dc:"配置ID"` +} + +type DeleteModelRes struct { + ID int64 `json:"id,string" dc:"配置ID"` } // GetModelReq 获取模型配置详情 @@ -128,7 +140,14 @@ type UpdateChatModelReq struct { g.Meta `path:"/updateChatModel" method:"post" tags:"模型管理" summary:"更新聊天模型" dc:"更新指定模型的聊天模型"` Id int64 `p:"id" json:"id" v:"required#model不能为空" dc:"模型id"` } +type UpdateChatModelRes struct { + ID int64 `json:"id,string" dc:"模型ID"` +} type GetIsChatModelReq struct { g.Meta `path:"/getIsChatModel" method:"get" tags:"模型管理" summary:"获取模型是否为聊天模型" dc:"根据模型ID获取是否为聊天模型"` } + +type GetIsChatModelRes struct { + Model any `json:"model" dc:"模型详情"` +} diff --git a/model/entity/asynch_model.go b/model/entity/asynch_model.go index 6d713ae..b1d655e 100644 --- a/model/entity/asynch_model.go +++ b/model/entity/asynch_model.go @@ -4,58 +4,62 @@ import "gitea.com/red-future/common/beans" type asynchModelCol struct { beans.SQLBaseCol - ModelName string - ModelType string - BaseURL string - HttpMethod string - HeadMsg string - FormJSON string - RequestMapping string - ResponseMapping string - ResponseBody string - TokenMapping string - Prompt string - IsPrivate string - IsChatModel string - ApiKey string - Enabled string - MaxConcurrency string - QueueLimit string - TimeoutSeconds string - ExpectedSeconds string - RetryTimes string - RetryQueueMaxSecs string - AutoCleanSeconds string - Remark string - IsOwner string + ModelName string + ModelType string + BaseURL string + HttpMethod string + HeadMsg string + FormJSON string + RequestMapping string + ResponseMapping string + ResponseBody string + ResponseTokenField string + Prompt string + IsPrivate string + IsChatModel string + ApiKey string + Enabled string + MaxConcurrency string + QueueLimit string + TimeoutSeconds string + ExpectedSeconds string + RetryTimes string + RetryQueueMaxSecs string + AutoCleanSeconds string + Remark string + IsOwner string + OperatorName string + TokenConfig string } var AsynchModelCol = asynchModelCol{ - SQLBaseCol: beans.DefSQLBaseCol, - ModelName: "model_name", - ModelType: "model_type", - BaseURL: "base_url", - HttpMethod: "http_method", - HeadMsg: "head_msg", - FormJSON: "form_json", - RequestMapping: "request_mapping", - ResponseMapping: "response_mapping", - ResponseBody: "response_body", - TokenMapping: "token_mapping", - Prompt: "prompt", - IsPrivate: "is_private", - IsChatModel: "is_chat_model", - ApiKey: "api_key", - Enabled: "enabled", - MaxConcurrency: "max_concurrency", - QueueLimit: "queue_limit", - TimeoutSeconds: "timeout_seconds", - ExpectedSeconds: "expected_seconds", - RetryTimes: "retry_times", - RetryQueueMaxSecs: "retry_queue_max_seconds", - AutoCleanSeconds: "auto_clean_seconds", - Remark: "remark", - IsOwner: "is_owner", + SQLBaseCol: beans.DefSQLBaseCol, + ModelName: "model_name", + ModelType: "model_type", + BaseURL: "base_url", + HttpMethod: "http_method", + HeadMsg: "head_msg", + FormJSON: "form_json", + RequestMapping: "request_mapping", + ResponseMapping: "response_mapping", + ResponseBody: "response_body", + ResponseTokenField: "response_token_field", + Prompt: "prompt", + IsPrivate: "is_private", + IsChatModel: "is_chat_model", + ApiKey: "api_key", + Enabled: "enabled", + MaxConcurrency: "max_concurrency", + QueueLimit: "queue_limit", + TimeoutSeconds: "timeout_seconds", + ExpectedSeconds: "expected_seconds", + RetryTimes: "retry_times", + RetryQueueMaxSecs: "retry_queue_max_seconds", + AutoCleanSeconds: "auto_clean_seconds", + Remark: "remark", + IsOwner: "is_owner", + OperatorName: "operator_name", + TokenConfig: "token_config", } // AsynchModel 异步模型配置 @@ -70,7 +74,7 @@ type AsynchModel struct { RequestMapping any `orm:"request_mapping" json:"requestMapping"` ResponseMapping any `orm:"response_mapping" json:"responseMapping"` ResponseBody any `orm:"response_body" json:"responseBody"` - TokenMapping string `orm:"token_mapping" json:"tokenMapping"` + ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"` Prompt string `orm:"prompt" json:"prompt"` IsPrivate *int `orm:"is_private" json:"isPrivate"` IsChatModel *int `orm:"is_chat_model" json:"isChatModel"` @@ -84,5 +88,7 @@ type AsynchModel struct { RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"retryQueueMaxSeconds"` AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"` Remark string `orm:"remark" json:"remark"` - IsOwner *int `json:"isOwner" orm:"is_owner"` // 1=当前用户创建的,0=超级管理员的 + IsOwner *int `json:"isOwner" orm:"is_owner"` + OperatorName string `orm:"operator_name" json:"operatorName"` + TokenConfig any `orm:"token_config" json:"tokenConfig"` } diff --git a/model/entity/asynch_model_type.go b/model/entity/asynch_model_type.go deleted file mode 100644 index 3a4d47a..0000000 --- a/model/entity/asynch_model_type.go +++ /dev/null @@ -1,26 +0,0 @@ -package entity - -import "gitea.com/red-future/common/beans" - -type asynchModelTypeCol struct { - beans.SQLBaseCol - TypeID string - TypeName string - Remark string -} - -var AsynchModelTypeCol = asynchModelTypeCol{ - SQLBaseCol: beans.DefSQLBaseCol, - TypeID: "type_id", - TypeName: "type_name", - Remark: "remark", -} - -// AsynchModelType 模型类型(图片/音频/视频等) -type AsynchModelType struct { - beans.SQLBaseDO `orm:",inline"` - TypeID int `orm:"type_id" json:"typeId"` - TypeName string `orm:"type_name" json:"type"` - Remark string `orm:"remark" json:"remark"` -} - diff --git a/service/auto_tune.go b/service/auto_tune.go index c36321c..481b910 100644 --- a/service/auto_tune.go +++ b/service/auto_tune.go @@ -2,8 +2,10 @@ package service import ( "context" + "errors" "fmt" "math" + "model-gateway/model/dto" "model-gateway/consts/public" "model-gateway/model/entity" @@ -34,9 +36,12 @@ type AutoTuneResult struct { // - 基于吞吐与 P90 执行耗时估算 max_concurrency 的运行时值(不超过 cap) // - queue_limit 与 expected_seconds 绑定(允许排队时间 = expected_seconds * 2),生成运行时值(不超过 cap) // - 单次调整幅度限制 ±50%,写入 Redis(带 TTL) -func AutoTune(ctx context.Context, windowSeconds int) ([]AutoTuneResult, error) { - if windowSeconds <= 0 { - windowSeconds = 3600 +func AutoTune(ctx context.Context, req *dto.AutoTuneReq) (res *dto.AutoTuneRes, err error) { + if req == nil { + return nil, errors.New("request cannot be nil") + } + if req.WindowSeconds <= 0 { + req.WindowSeconds = 3600 // 默认1小时 } // 1) 读取模型配置(cap),按 model_name 聚合去重(如果表里有多租户重复数据,取较大上限) var modelRows []*entity.AsynchModel @@ -68,7 +73,7 @@ func AutoTune(ctx context.Context, windowSeconds int) ([]AutoTuneResult, error) } } if len(modelMap) == 0 { - return []AutoTuneResult{}, nil + return nil, errors.New("no models found") } // 2) 统计指定窗口:按 model_name 计算 cnt 和 P90 执行耗时 @@ -89,7 +94,7 @@ SELECT model_name, AND finished_at IS NOT NULL AND finished_at >= (NOW() - (? || ' seconds')::interval) GROUP BY model_name`, public.TableNameTask) - r, err := gfdb.DB(ctx).GetAll(ctx, sql, windowSeconds) + r, err := gfdb.DB(ctx).GetAll(ctx, sql, req.WindowSeconds) if err != nil { return nil, err } @@ -189,6 +194,8 @@ SELECT model_name, }) } - g.Log().Infof(ctx, "[auto_tune] done models=%d windowSeconds=%d", len(out), windowSeconds) - return out, nil + g.Log().Infof(ctx, "[auto_tune] done models=%d windowSeconds=%d", len(out), req.WindowSeconds) + return &dto.AutoTuneRes{ + List: out, + }, nil } diff --git a/service/callback.go b/service/callback.go deleted file mode 100644 index bb7bee8..0000000 --- a/service/callback.go +++ /dev/null @@ -1,67 +0,0 @@ -package service - -import ( - "context" - "encoding/json" - - "model-gateway/model/entity" - - "gitea.com/red-future/common/http" - "github.com/gogf/gf/v2/frame/g" -) - -// triggerCallback 任务成功后的回调: -// - JSON body 参数:task_id/state/oss_file/file_type/text(可选) -func triggerCallback(ctx context.Context, t *entity.AsynchTask) { - callbackURL := t.BizName + t.CallbackURL - headers := forwardHeaders(ctx) - var req struct{} - payload := map[string]interface{}{ - "task_id": t.TaskID, - "state": t.State, - "oss_file": t.OssFile, - "file_type": t.FileType, - "text": t.TextResult, - "error_msg": t.ErrorMsg, - } - jsonData, err := json.Marshal(payload) - if err != nil { - g.Log().Warningf(ctx, "[回调] JSON序列化失败 taskId=%s 错误=%v", t.TaskID, err) - return - } - g.Log().Infof(ctx, "[回调] 开始发送 taskId=%s 回调地址=%s 请求头数量=%d 消息体大小=%d字节", - t.TaskID, callbackURL, len(headers), len(jsonData)) - - err = http.Post(ctx, callbackURL, headers, &req, jsonData) - if err != nil { - g.Log().Warningf(ctx, "[回调] 发送失败 taskId=%s 回调地址=%s 错误=%v", t.TaskID, callbackURL, err) - return - } - g.Log().Infof(ctx, "[回调] 发送成功 taskId=%s 回调地址=%s 消息体大小=%d字节", t.TaskID, callbackURL, len(jsonData)) -} - -// triggerPromptsCallback 任务成功后的提示词回调 -// - JSON body 参数:epicycleId(轮次id)/textResult(模型回答消息) -func triggerPromptsCallback(ctx context.Context, t *entity.AsynchTask, epicycleId int64) { - callbackURL := "prompts-core/session/sessionCallback" - headers := forwardHeaders(ctx) - var req struct{} - payload := map[string]interface{}{ - "epicycleId": epicycleId, - "text": t.TextResult, - } - jsonData, err := json.Marshal(payload) - if err != nil { - g.Log().Warningf(ctx, "[提示词回调] JSON序列化失败 epicycleId=%d 错误=%v", epicycleId, err) - return - } - g.Log().Infof(ctx, "[提示词回调] 开始发送 epicycleId=%d 回调地址=%s 请求头数量=%d 消息体大小=%d字节", - t.EpicycleId, callbackURL, len(headers), len(jsonData)) - - err = http.Post(ctx, callbackURL, headers, &req, jsonData) - if err != nil { - g.Log().Warningf(ctx, "[提示词回调] 发送失败 epicycleId=%d 回调地址=%s 错误=%v", t.EpicycleId, callbackURL, err) - return - } - g.Log().Infof(ctx, "[提示词回调] 发送成功 epicycleId=%d 回调地址=%s 消息体大小=%d字节", t.EpicycleId, callbackURL, len(jsonData)) -} diff --git a/service/cleaner.go b/service/cleaner.go index 0aea6eb..ad80d54 100644 --- a/service/cleaner.go +++ b/service/cleaner.go @@ -2,6 +2,8 @@ package service import ( "context" + "model-gateway/model/dto" + "os" "time" "model-gateway/dao" @@ -14,14 +16,14 @@ var Cleaner = &cleaner{} type cleaner struct{} // RunOnce 由上层定时任务触发:执行一次清理/重试 -func (c *cleaner) RunOnce(ctx context.Context) { +func (c *cleaner) RunOnce(ctx context.Context) (res *dto.CleanWorkRes, err error) { // 1) 清理已下载(state=4)且过期的任务(硬删除 + OSS) expired, err := dao.Task.ListExpiredDownloadedGlobal(ctx, 200) if err != nil { g.Log().Errorf(ctx, "[cleaner] list expired(downloaded) error: %v", err) } else { for _, t := range expired { - deleteTmpResult(t.TmpFile) + _ = os.Remove(t.TmpFile) _ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id) } g.Log().Infof(ctx, "[cleaner] expired(downloaded) cleaned, count=%d", len(expired)) @@ -82,11 +84,14 @@ func (c *cleaner) RunOnce(ctx context.Context) { g.Log().Errorf(ctx, "[cleaner] list failed exhausted error: %v", err) } else { for _, t := range exhausted { - deleteTmpResult(t.TmpFile) + _ = os.Remove(t.TmpFile) // 重试耗尽硬删除:释放闸门占位(兜底,若此前已释放则幂等) ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) _ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id) } g.Log().Infof(ctx, "[cleaner] failed exhausted cleaned, count=%d", len(exhausted)) } + return &dto.CleanWorkRes{ + Ok: true, + }, nil } diff --git a/service/file_detect.go b/service/file_detect.go index 1ec85fb..6d43c33 100644 --- a/service/file_detect.go +++ b/service/file_detect.go @@ -1,47 +1 @@ package service - -import ( - "net/http" - "strings" -) - -// DetectFileType 根据返回的二进制内容推断 contentType + 扩展名(尽量稳定) -func DetectFileType(data []byte) (contentType string, ext string) { - if len(data) == 0 { - return "application/octet-stream", "" - } - ct := http.DetectContentType(data) - // gateway.DetectContentType 可能带 charset 等参数:text/plain; charset=utf-8 - if idx := strings.Index(ct, ";"); idx > 0 { - ct = strings.TrimSpace(ct[:idx]) - } - switch ct { - case "audio/mpeg": - return ct, ".mp3" - case "audio/wave", "audio/wav", "audio/x-wav": - return ct, ".wav" - case "video/mp4": - return ct, ".mp4" - case "image/png": - return ct, ".png" - case "image/jpeg": - return ct, ".jpg" - case "application/pdf": - return ct, ".pdf" - case "text/plain": - return ct, ".txt" - case "application/json": - return ct, ".json" - default: - // 兜底:尝试从 ct 截取 subtype 作为后缀(例如 application/json) - if parts := strings.Split(ct, "/"); len(parts) == 2 { - sub := parts[1] - // 避免出现 "plain; charset=utf-8" 之类的后缀 - if idx := strings.Index(sub, ";"); idx > 0 { - sub = strings.TrimSpace(sub[:idx]) - } - return ct, "." + sub - } - return ct, "" - } -} diff --git a/service/gateway/gateway_http_service.go b/service/gateway/gateway_http_service.go new file mode 100644 index 0000000..608f8c4 --- /dev/null +++ b/service/gateway/gateway_http_service.go @@ -0,0 +1,171 @@ +package gateway + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "mime/multipart" + "model-gateway/common/util" + "model-gateway/model/entity" + "time" + + commonHttp "gitea.com/red-future/common/http" + "github.com/gogf/gf/v2/frame/g" + "github.com/gogf/gf/v2/util/guid" +) + +type uploadFileResponse struct { + FileURL string `json:"fileURL"` // 文件 URL + FileSize int `json:"fileSize"` // 文件大小(字节) + FileName string `json:"fileName"` // 文件名 + FileFormat string `json:"fileFormat"` // 文件格式 + FileAddressPrefix string `json:"fileAddressPrefix"` // 文件地址前缀 +} + +func UploadByTask(ctx context.Context, _ *entity.AsynchTask, data []byte, fileExt string, _ string) (ossURL string, err error) { + // multipart + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + + ext := fileExt + if ext == "" { + ext = ".bin" + } + if ext[0] != '.' { + ext = "." + ext + } + + filename := fmt.Sprintf("asynch_%d_%s%s", time.Now().Unix(), guid.S(), ext) + part, err := writer.CreateFormFile("file", filename) + if err != nil { + return "", err + } + if _, err = part.Write(data); err != nil { + return "", err + } + headers := util.ForwardHeaders(ctx) + fullURL := "oss/file/uploadFile" + g.Log().Infof(ctx, "[OSS] upload start url=%s filename=%s size=%d", fullURL, filename, len(data)) + + var resp uploadFileResponse + if err = commonHttp.Post(ctx, fullURL, headers, &resp, body.Bytes()); err != nil { + return "", err + } + g.Log().Infof(ctx, "[OSS] upload success url=%s size=%d format=%s", resp.FileURL, resp.FileSize, resp.FileFormat) + return resp.FileURL, nil +} + +// TriggerCallback 任务成功后的回调: +// - JSON body 参数:task_id/state/oss_file/file_type/text(可选) +func TriggerCallback(ctx context.Context, t *entity.AsynchTask) { + headers := util.ForwardHeaders(ctx) + var req struct{} + payload := map[string]interface{}{ + "task_id": t.TaskID, + "state": t.State, + "oss_file": t.OssFile, + "file_type": t.FileType, + "text": t.TextResult, + "error_msg": t.ErrorMsg, + } + jsonData, err := json.Marshal(payload) + if err != nil { + g.Log().Warningf(ctx, "[回调] JSON序列化失败 taskId=%s 错误=%v", t.TaskID, err) + return + } + g.Log().Infof(ctx, "[回调] 开始发送 taskId=%s 回调地址=%s 请求头数量=%d 消息体大小=%d字节", + t.TaskID, t.CallbackURL, len(headers), len(jsonData)) + + err = commonHttp.Post(ctx, t.CallbackURL, headers, &req, jsonData) + if err != nil { + g.Log().Warningf(ctx, "[回调] 发送失败 taskId=%s 回调地址=%s 错误=%v", t.TaskID, t.CallbackURL, err) + return + } + g.Log().Infof(ctx, "[回调] 发送成功 taskId=%s 回调地址=%s 消息体大小=%d字节", t.TaskID, t.CallbackURL, len(jsonData)) +} + +// TriggerPromptsCallback 任务成功后的提示词回调 +// - JSON body 参数:epicycleId(轮次id)/textResult(模型回答消息) +func TriggerPromptsCallback(ctx context.Context, t *entity.AsynchTask, epicycleId int64) { + callbackURL := "prompts-core/session/sessionCallback" + headers := util.ForwardHeaders(ctx) + var req struct{} + payload := map[string]interface{}{ + "epicycleId": epicycleId, + "text": t.TextResult, + } + jsonData, err := json.Marshal(payload) + if err != nil { + g.Log().Warningf(ctx, "[提示词回调] JSON序列化失败 epicycleId=%d 错误=%v", epicycleId, err) + return + } + g.Log().Infof(ctx, "[提示词回调] 开始发送 epicycleId=%d 回调地址=%s 请求头数量=%d 消息体大小=%d字节", + t.EpicycleId, callbackURL, len(headers), len(jsonData)) + + err = commonHttp.Post(ctx, callbackURL, headers, &req, jsonData) + if err != nil { + g.Log().Warningf(ctx, "[提示词回调] 发送失败 epicycleId=%d 回调地址=%s 错误=%v", t.EpicycleId, callbackURL, err) + return + } + g.Log().Infof(ctx, "[提示词回调] 发送成功 epicycleId=%d 回调地址=%s 消息体大小=%d字节", t.EpicycleId, callbackURL, len(jsonData)) +} + +// IsSuperAdmin 调用admin-go服务检查是否是超级管理员 +func IsSuperAdmin(ctx context.Context) (res bool, err error) { + headers := util.ForwardHeaders(ctx) + var r = make(map[string]bool) + if err = commonHttp.Get(ctx, "admin-go/api/v1/system/user/checkIsSuperAdmin", headers, &r); err != nil { + return false, err + } + return r["isSuperAdmin"], err +} + +//// callback 向回调地址 POST 任务结果(与查询接口 GetTaskRes 出参一致) +//func (s *audioTaskService) callback(ctx context.Context, taskID, status, errMsg, callbackURL string) { +// if callbackURL == "" { +// return +// } +// +// task, _ := dao.TranscribeTask.GetByTaskID(ctx, taskID) +// if task == nil { +// g.Log().Errorf(ctx, "[回调 %s] 任务不存在", taskID) +// return +// } +// +// detailList, _ := dao.TranscribeTaskDetail.ListByTaskID(ctx, taskID) +// detailItems := make([]dto.TranscribeTaskDetailItem, 0, len(detailList)) +// for i := range detailList { +// detailItems = append(detailItems, dao.DetailEntityToItem(&detailList[i])) +// } +// +// // 构建与查询接口一致的 taskInfo +// taskInfo := dao.EntityToItem(task) +// +// // 兼容历史数据: 从 result 中补全 scenes 等字段 +// detailItems = enrichDetailsFromResult(task.Result, detailItems) +// +// payload := dto.CallbackPayload{ +// TaskInfo: taskInfo, +// DetailList: detailItems, +// } +// +// body, _ := json.Marshal(payload) +// +// // 透传调用方的用户信息 +// userJSON, _ := json.Marshal(beans.User{UserName: "admin", TenantId: 1}) +// +// req, _ := http.NewRequest("POST", callbackURL, bytes.NewReader(body)) +// req.Header.Set("Content-Type", "application/json") +// req.Header.Set("X-User-Info", string(userJSON)) +// +// resp, reqErr := http.DefaultClient.Do(req) +// if reqErr != nil { +// g.Log().Errorf(ctx, "[回调 %s] 请求失败: %v", taskID, reqErr) +// return +// } +// defer resp.Body.Close() +// +// respBody, _ := io.ReadAll(resp.Body) +// g.Log().Infof(ctx, "[回调 %s] 响应 status=%d, body=%s", taskID, resp.StatusCode, string(respBody)) +//} diff --git a/service/headers.go b/service/headers.go deleted file mode 100644 index e83ae4d..0000000 --- a/service/headers.go +++ /dev/null @@ -1,53 +0,0 @@ -package service - -import ( - "context" - - "gitea.com/red-future/common/utils" - "github.com/gogf/gf/v2/frame/g" -) - -// asyncCtx 固化异步执行所需的 token/user,避免请求结束后丢失(仅在“同请求内起 goroutine”有用)。 -// 本项目当前是“落库 + 后台 worker”模式,因此还会把必要信息持久化到任务表的 request_payload 中。 -func asyncCtx(ctx context.Context) context.Context { - asyncCtx := context.WithoutCancel(ctx) - if r := g.RequestFromCtx(ctx); r != nil { - if token := r.Header.Get("Authorization"); token != "" { - asyncCtx = context.WithValue(asyncCtx, "token", token) - } - if userInfo := r.Header.Get("X-User-Info"); userInfo != "" { - asyncCtx = context.WithValue(asyncCtx, "xUserInfo", userInfo) - } - } - if user, err := utils.GetUserInfo(ctx); err == nil && user != nil { - asyncCtx = context.WithValue(asyncCtx, "user", user) - } - return asyncCtx -} - -// forwardHeaders 透传调用链路中必须的头信息(优先使用 ctx 里固化的 token / xUserInfo)。 -func forwardHeaders(ctx context.Context) map[string]string { - headers := make(map[string]string) - - if token, ok := ctx.Value("token").(string); ok && token != "" { - headers["Authorization"] = token - } - if x, ok := ctx.Value("xUserInfo").(string); ok && x != "" { - headers["X-User-Info"] = x - } - - // 兜底:从请求头拿 - if r := g.RequestFromCtx(ctx); r != nil { - if headers["Authorization"] == "" { - if token := r.Header.Get("Authorization"); token != "" { - headers["Authorization"] = token - } - } - if headers["X-User-Info"] == "" { - if userInfo := r.Header.Get("X-User-Info"); userInfo != "" { - headers["X-User-Info"] = userInfo - } - } - } - return headers -} diff --git a/service/model_service.go b/service/model_service.go index 7eaa562..5981022 100644 --- a/service/model_service.go +++ b/service/model_service.go @@ -3,13 +3,15 @@ package service import ( "context" "errors" + "model-gateway/common/util" + "model-gateway/consts/public" "model-gateway/dao" "model-gateway/model/dto" "model-gateway/model/entity" + "model-gateway/service/gateway" "gitea.com/red-future/common/beans" "gitea.com/red-future/common/db/gfdb" - "gitea.com/red-future/common/http" "gitea.com/red-future/common/utils" "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/frame/g" @@ -20,28 +22,20 @@ var Model = &modelService{} type modelService struct{} -// IsSuperAdmin 调用admin-go服务检查是否是超级管理员 -func (s *modelService) IsSuperAdmin(ctx context.Context) (res bool, err error) { - headers := forwardHeaders(ctx) - var r = make(map[string]bool) - if err = http.Get(ctx, "admin-go/api/v1/system/user/checkIsSuperAdmin", headers, &r); err != nil { - return false, err - } - return r["isSuperAdmin"], err -} - func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) { // 获取当前会话模型 if !g.IsEmpty(req.IsChatModel) && *req.IsChatModel == 1 { var model *entity.AsynchModel - model, err = dao.Model.GetByIsChatModel(ctx) + model, err = dao.Model.Get(ctx, &entity.AsynchModel{ + IsChatModel: new(1), + }) if err != nil { return nil, err } // 如果有会话模型,那就改变为 0 if model != nil { - _, err = dao.Model.Update(ctx, &dto.UpdateModelReq{ - ID: model.Id, + _, err = dao.Model.Update(ctx, &entity.AsynchModel{ + SQLBaseDO: beans.SQLBaseDO{Id: model.Id}, IsChatModel: gconv.PtrInt(0), }) if err != nil { @@ -51,14 +45,40 @@ func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res } req.IsOwner = gconv.PtrInt(1) - admin, err := s.IsSuperAdmin(ctx) + admin, err := gateway.IsSuperAdmin(ctx) if err != nil { return } if admin { req.IsOwner = gconv.PtrInt(0) } - id, err := dao.Model.Insert(ctx, req) + id, err := dao.Model.Insert(ctx, &entity.AsynchModel{ + ModelName: req.ModelName, + ModelType: req.ModelType, + BaseURL: req.BaseURL, + HttpMethod: req.HttpMethod, + HeadMsg: req.HeadMsg, + Form: req.Form, + RequestMapping: req.RequestMapping, + ResponseMapping: req.ResponseMapping, + ResponseBody: req.ResponseBody, + ResponseTokenField: req.ResponseTokenField, + IsPrivate: req.IsPrivate, + IsChatModel: req.IsChatModel, + ApiKey: req.ApiKey, + Enabled: req.Enabled, + MaxConcurrency: req.MaxConcurrency, + QueueLimit: req.QueueLimit, + TimeoutSeconds: req.TimeoutSeconds, + ExpectedSeconds: req.ExpectedSeconds, + RetryTimes: req.RetryTimes, + RetryQueueMaxSeconds: req.RetryQueueMaxSeconds, + AutoCleanSeconds: req.AutoCleanSeconds, + Remark: req.Remark, + IsOwner: req.IsOwner, + OperatorName: req.OperatorName, + TokenConfig: req.TokenConfig, + }) if err != nil { return nil, err } @@ -69,7 +89,9 @@ func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) erro //根据当前 isChatModel 来判断是否更新模型 if req.IsChatModel == gconv.PtrInt(1) { //判断当前用户是否有会话模型 - model, err := dao.Model.GetByIsChatModel(ctx) + model, err := dao.Model.Get(ctx, &entity.AsynchModel{ + IsChatModel: new(1), + }) if err != nil { return err } @@ -79,68 +101,146 @@ func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) erro } req.IsOwner = gconv.PtrInt(1) - admin, err := s.IsSuperAdmin(ctx) + admin, err := gateway.IsSuperAdmin(ctx) if err != nil { return err } if admin { req.IsOwner = gconv.PtrInt(0) - _, err = dao.Model.Update(ctx, req) + _, err = dao.Model.Update(ctx, &entity.AsynchModel{ + SQLBaseDO: beans.SQLBaseDO{Id: req.ID}, + ModelName: req.ModelName, + ModelType: req.ModelType, + BaseURL: req.BaseURL, + HttpMethod: req.HttpMethod, + HeadMsg: req.HeadMsg, + Form: req.Form, + RequestMapping: req.RequestMapping, + ResponseMapping: req.ResponseMapping, + ResponseBody: req.ResponseBody, + ResponseTokenField: req.ResponseTokenField, + IsPrivate: req.IsPrivate, + IsChatModel: req.IsChatModel, + ApiKey: req.ApiKey, + Enabled: req.Enabled, + MaxConcurrency: req.MaxConcurrency, + QueueLimit: req.QueueLimit, + TimeoutSeconds: req.TimeoutSeconds, + ExpectedSeconds: req.ExpectedSeconds, + RetryTimes: req.RetryTimes, + RetryQueueMaxSeconds: req.RetryQueueMaxSeconds, + AutoCleanSeconds: req.AutoCleanSeconds, + Remark: req.Remark, + IsOwner: req.IsOwner, + OperatorName: req.OperatorName, + TokenConfig: req.TokenConfig, + }) if err != nil { return err } return nil } - - var user *beans.User - user, err = utils.GetUserInfo(ctx) - if err != nil { - return err - } // 判断当前传过来的模型id的模型是否是超级管理员的。如果是超管的进行创建,否则更新 - var count int - count, err = dao.Model.Count(ctx, &dto.GetModelReq{ - ID: req.ID, - Creator: user.UserName, + model, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{ + SQLBaseDO: beans.SQLBaseDO{Id: req.ID}, }) if err != nil { return err } - if count == 0 { + if model.TenantId == 1 { insertDto := new(dto.CreateModelReq) err = gconv.Struct(req, insertDto) if err != nil { return err } - _, err = dao.Model.Insert(ctx, insertDto) + _, err = dao.Model.Insert(ctx, &entity.AsynchModel{ + ModelName: req.ModelName, + ModelType: req.ModelType, + BaseURL: req.BaseURL, + HttpMethod: req.HttpMethod, + HeadMsg: req.HeadMsg, + Form: req.Form, + RequestMapping: req.RequestMapping, + ResponseMapping: req.ResponseMapping, + ResponseBody: req.ResponseBody, + ResponseTokenField: req.ResponseTokenField, + IsPrivate: req.IsPrivate, + IsChatModel: req.IsChatModel, + ApiKey: req.ApiKey, + Enabled: req.Enabled, + MaxConcurrency: req.MaxConcurrency, + QueueLimit: req.QueueLimit, + TimeoutSeconds: req.TimeoutSeconds, + ExpectedSeconds: req.ExpectedSeconds, + RetryTimes: req.RetryTimes, + RetryQueueMaxSeconds: req.RetryQueueMaxSeconds, + AutoCleanSeconds: req.AutoCleanSeconds, + Remark: req.Remark, + IsOwner: req.IsOwner, + OperatorName: req.OperatorName, + TokenConfig: req.TokenConfig, + }) return err } - _, err = dao.Model.Update(ctx, req) + _, err = dao.Model.Update(ctx, &entity.AsynchModel{ + SQLBaseDO: beans.SQLBaseDO{Id: req.ID}, + ModelName: req.ModelName, + ModelType: req.ModelType, + BaseURL: req.BaseURL, + HttpMethod: req.HttpMethod, + HeadMsg: req.HeadMsg, + Form: req.Form, + RequestMapping: req.RequestMapping, + ResponseMapping: req.ResponseMapping, + ResponseBody: req.ResponseBody, + ResponseTokenField: req.ResponseTokenField, + IsPrivate: req.IsPrivate, + IsChatModel: req.IsChatModel, + ApiKey: req.ApiKey, + Enabled: req.Enabled, + MaxConcurrency: req.MaxConcurrency, + QueueLimit: req.QueueLimit, + TimeoutSeconds: req.TimeoutSeconds, + ExpectedSeconds: req.ExpectedSeconds, + RetryTimes: req.RetryTimes, + RetryQueueMaxSeconds: req.RetryQueueMaxSeconds, + AutoCleanSeconds: req.AutoCleanSeconds, + Remark: req.Remark, + IsOwner: req.IsOwner, + OperatorName: req.OperatorName, + TokenConfig: req.TokenConfig, + }) return err } -func (s *modelService) Delete(ctx context.Context, id string) error { - _, err := dao.Model.DeleteByID(ctx, id) +func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) error { + _, err := dao.Model.Delete(ctx, &entity.AsynchModel{ + SQLBaseDO: beans.SQLBaseDO{Id: req.ID}, + }) return err } -func (s *modelService) Get(ctx context.Context, id int64) (*entity.AsynchModel, error) { - model, err := dao.Model.Get(ctx, id) +func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetModelRes, error) { + model, err := dao.Model.Get(ctx, &entity.AsynchModel{ + SQLBaseDO: beans.SQLBaseDO{Id: req.ID}, + }) if err != nil { return nil, err } - model.Form = ParseJSONField(model.Form) - model.RequestMapping = ParseJSONField(model.RequestMapping) - model.ResponseMapping = ParseJSONField(model.ResponseMapping) - model.ResponseBody = ParseJSONField(model.ResponseBody) - return model, nil + model.Form = util.ParseJSONField(model.Form) + model.RequestMapping = util.ParseJSONField(model.RequestMapping) + model.ResponseMapping = util.ParseJSONField(model.ResponseMapping) + model.ResponseBody = util.ParseJSONField(model.ResponseBody) + return &dto.GetModelRes{ + Model: model, + }, nil } -func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (list []*entity.AsynchModel, total int, err error) { +func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) { var models []*entity.AsynchModel req.IsOwner = gconv.PtrInt(1) - admin, err := s.IsSuperAdmin(ctx) + admin, err := gateway.IsSuperAdmin(ctx) if err != nil { return } @@ -151,63 +251,55 @@ func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (list [] var user *beans.User user, err = utils.GetUserInfo(ctx) if err != nil { - return nil, 0, err + return nil, err } req.Creator = user.UserName - models, total, err = dao.Model.GetByCreatorAndPlatform(ctx, req) + models, total, err := dao.Model.GetByCreatorAndPlatform(ctx, req) if err != nil { return } // 处理列表中每条记录的 JSONB 字段 for _, m := range models { - m.Form = ParseJSONField(m.Form) - m.RequestMapping = ParseJSONField(m.RequestMapping) - m.ResponseMapping = ParseJSONField(m.ResponseMapping) - m.ResponseBody = ParseJSONField(m.ResponseBody) + m.Form = util.ParseJSONField(m.Form) + m.RequestMapping = util.ParseJSONField(m.RequestMapping) + m.ResponseMapping = util.ParseJSONField(m.ResponseMapping) + m.ResponseBody = util.ParseJSONField(m.ResponseBody) } - return models, total, nil + return &dto.ListModelRes{ + List: models, + Total: total, + }, nil } // GetModelTypesFromConfig 从配置文件读取模型类型 -func GetModelTypesFromConfig(ctx context.Context) map[int]string { - typeMap := make(map[int]string) - - // 读取配置 - configMap := g.Cfg().MustGet(ctx, "modelType.types").Map() - for k, v := range configMap { - typeID := gconv.Int(k) - typeName := gconv.String(v) - if typeID > 0 && typeName != "" { - typeMap[typeID] = typeName - } +func GetModelTypesFromConfig() (res *dto.TypeItem, err error) { + // 返回副本,避免外部修改 + types := make(map[int]string, len(public.ModelTypeName)) + for k, v := range public.ModelTypeName { + types[k] = v } - // 如果配置为空,使用默认值 - if len(typeMap) == 0 { - typeMap = map[int]string{ - 1: "推理模型", - 2: "图片模型", - 3: "音频模型", - 4: "向量化模型", - 5: "全模态模型", - } - } - return typeMap + return &dto.TypeItem{ + Type: types, + }, nil } func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error { // 校验新会话模型是否存在 - newModel, err := dao.Model.Get(ctx, req.Id) + newModel, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{ + SQLBaseDO: beans.SQLBaseDO{Id: req.Id}, + }) if err != nil { return err } if newModel == nil { return errors.New("新会话模型不存在") } - // 获取当前用户会话模型 - currentModel, err := dao.Model.GetByIsChatModel(ctx) + currentModel, err := dao.Model.Get(ctx, &entity.AsynchModel{ + IsChatModel: new(1), + }) if err != nil { return err } @@ -219,8 +311,8 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM // 如果点击的就是当前会话模型(已经是1),取消它(设为0) if currentModel.Id != req.Id { - _, err = dao.Model.Update(ctx, &dto.UpdateModelReq{ - ID: currentModel.Id, + _, err = dao.Model.Update(ctx, &entity.AsynchModel{ + SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id}, IsChatModel: gconv.PtrInt(0), }) if err != nil { @@ -230,8 +322,8 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM } // 设置当前为会话模型(设为1) - _, err = dao.Model.Update(ctx, &dto.UpdateModelReq{ - ID: req.Id, + _, err = dao.Model.Update(ctx, &entity.AsynchModel{ + SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id}, IsChatModel: gconv.PtrInt(1), }) return err @@ -239,17 +331,21 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM return err } -func (s *modelService) GetIsChatModel(ctx context.Context) (*entity.AsynchModel, error) { - model, err := dao.Model.GetByIsChatModel(ctx) +func (s *modelService) GetIsChatModel(ctx context.Context) (*dto.GetIsChatModelRes, error) { + model, err := dao.Model.Get(ctx, &entity.AsynchModel{ + IsChatModel: new(1), + }) if err != nil { return nil, err } if model == nil { return nil, nil } - model.Form = ParseJSONField(model.Form) - model.RequestMapping = ParseJSONField(model.RequestMapping) - model.ResponseMapping = ParseJSONField(model.ResponseMapping) - model.ResponseBody = ParseJSONField(model.ResponseBody) - return model, nil + model.Form = util.ParseJSONField(model.Form) + model.RequestMapping = util.ParseJSONField(model.RequestMapping) + model.ResponseMapping = util.ParseJSONField(model.ResponseMapping) + model.ResponseBody = util.ParseJSONField(model.ResponseBody) + return &dto.GetIsChatModelRes{ + Model: model, + }, nil } diff --git a/service/payload.go b/service/payload.go deleted file mode 100644 index b6873fc..0000000 --- a/service/payload.go +++ /dev/null @@ -1,25 +0,0 @@ -package service - -import "github.com/gogf/gf/v2/util/gconv" - -// parseStoredPayload 解析入库的 request_payload,拆出模型调用 payload 与透传 headers -// 入库格式:{"payload": , "headers": {"Authorization": "...", "X-User-Info":"..."}} -func parseStoredPayload(v any) (payload any, headers map[string]string) { - if v == nil { - return nil, nil - } - m := gconv.Map(v) - if len(m) == 0 { - return v, nil - } - if h, ok := m["headers"]; ok { - headers = gconv.MapStrStr(h) - } - if p, ok := m["payload"]; ok { - payload = p - } else { - payload = v - } - return -} - diff --git a/service/storage.go b/service/storage.go deleted file mode 100644 index 759767a..0000000 --- a/service/storage.go +++ /dev/null @@ -1,18 +0,0 @@ -package service - -import ( - "context" - "errors" - - "model-gateway/model/entity" -) - -// StorageService 结果存储(OSS/MinIO)抽象 -type StorageService interface { - UploadByTask(ctx context.Context, t *entity.AsynchTask, data []byte, fileExt string, contentType string) (ossURL string, err error) -} - -// Storage 默认存储实现(优先对接你们的 oss 文件服务;必要时也可以切到 MinIO) -var Storage StorageService = &ossStorage{} - -var ErrStorageNotConfigured = errors.New("存储未配置") diff --git a/service/storage_oss.go b/service/storage_oss.go deleted file mode 100644 index a938d4a..0000000 --- a/service/storage_oss.go +++ /dev/null @@ -1,81 +0,0 @@ -package service - -import ( - "bytes" - "context" - "fmt" - "mime/multipart" - "time" - - "model-gateway/model/entity" - - commonHttp "gitea.com/red-future/common/http" - "github.com/gogf/gf/v2/frame/g" - "github.com/gogf/gf/v2/util/gconv" - "github.com/gogf/gf/v2/util/guid" -) - -// 对接你们的 oss 文件服务:POST oss/file/uploadFile (multipart/form-data) -type ossStorage struct{} - -type uploadFileResponse struct { - FileURL string `json:"fileURL"` // 文件 URL - FileSize int `json:"fileSize"` // 文件大小(字节) - FileName string `json:"fileName"` // 文件名 - FileFormat string `json:"fileFormat"` // 文件格式 - FileAddressPrefix string `json:"fileAddressPrefix"` // 文件地址前缀 -} - -func (s *ossStorage) UploadByTask(ctx context.Context, _ *entity.AsynchTask, data []byte, fileExt string, _ string) (ossURL string, err error) { - // multipart - body := &bytes.Buffer{} - writer := multipart.NewWriter(body) - - ext := fileExt - if ext == "" { - ext = ".bin" - } - if ext[0] != '.' { - ext = "." + ext - } - - filename := fmt.Sprintf("asynch_%d_%s%s", time.Now().Unix(), guid.S(), ext) - part, err := writer.CreateFormFile("file", filename) - if err != nil { - return "", err - } - if _, err := part.Write(data); err != nil { - return "", err - } - contentType := writer.FormDataContentType() - if err := writer.Close(); err != nil { - return "", err - } - - headers := forwardHeaders(ctx) - headers["Content-Type"] = contentType - - fullURL := "oss/file/uploadFile" - g.Log().Infof(ctx, "[OSS] upload start url=%s filename=%s size=%d", fullURL, filename, len(data)) - - var resp uploadFileResponse - if err := commonHttp.Post(ctx, fullURL, headers, &resp, body.Bytes()); err != nil { - return "", err - } - g.Log().Infof(ctx, "[OSS] upload success url=%s size=%d format=%s", resp.FileURL, resp.FileSize, resp.FileFormat) - return resp.FileURL, nil -} - -// setTaskHeadersToCtx 把任务入库时保存的 header 信息注入 ctx,给 worker 调 OSS 用 -func setTaskHeadersToCtx(ctx context.Context, headers map[string]string) context.Context { - if headers == nil { - return ctx - } - if v := gconv.String(headers["Authorization"]); v != "" { - ctx = context.WithValue(ctx, "token", v) - } - if v := gconv.String(headers["X-User-Info"]); v != "" { - ctx = context.WithValue(ctx, "xUserInfo", v) - } - return ctx -} diff --git a/service/task_service.go b/service/task_service.go index 91941f7..2a604cd 100644 --- a/service/task_service.go +++ b/service/task_service.go @@ -3,7 +3,7 @@ package service import ( "context" "errors" - "fmt" + "model-gateway/common/util" "time" "model-gateway/dao" @@ -21,13 +21,13 @@ var Task = &taskService{} type taskService struct{} func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) { - fmt.Printf("打印请求:%+v", req) startAt := time.Now() // 固化 token/user 等信息 - ctx = asyncCtx(ctx) - + ctx = util.AsyncCtx(ctx) // 1) 检查模型配置 - m, err := dao.Model.GetByModelName(ctx, req.ModelName) + m, err := dao.Model.Get(ctx, &entity.AsynchModel{ + ModelName: req.ModelName, + }) if err != nil { return nil, err } @@ -51,7 +51,7 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res * // 将调用模型的 payload 与透传头信息一起存入 request_payload,供后台 worker 使用 storedPayload := map[string]any{ "payload": req.RequestPayload, - "headers": forwardHeaders(ctx), + "headers": util.ForwardHeaders(ctx), } t := &entity.AsynchTask{ @@ -127,7 +127,9 @@ func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string, defer ticker.Stop() tryRun := func() bool { - t, err := dao.Task.GetByTaskID(ctx, taskID) + t, err := dao.Task.Get(ctx, &entity.AsynchTask{ + TaskID: taskID, + }) if err != nil { g.Log().Warningf(ctx, "[task-auto-run][stop] taskId=%s reason=query_failed err=%v", taskID, err) return true @@ -138,7 +140,7 @@ func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string, } switch t.State { case 0: - if err := AsyncWorker.RunByTaskID(ctx, taskID, epicycleId); err != nil { + if err = AsyncWorker.RunByTaskID(ctx, taskID, epicycleId); err != nil { g.Log().Warningf(ctx, "[task-auto-run][retry] taskId=%s state=0 err=%v", taskID, err) } else { g.Log().Infof(ctx, "[task-auto-run][triggered] taskId=%s state=0", taskID) @@ -175,7 +177,9 @@ func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string, } func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) { - t, err := dao.Task.GetByTaskID(ctx, taskID) + t, err := dao.Task.Get(ctx, &entity.AsynchTask{ + TaskID: taskID, + }) if err != nil { return nil, err } @@ -209,7 +213,9 @@ func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (r continue } // 按模型配置决定保留时间 - m, err := dao.Model.GetByModelName(ctx, t.ModelName) + m, err := dao.Model.Get(ctx, &entity.AsynchModel{ + ModelName: t.ModelName, + }) if err != nil { return nil, err } diff --git a/service/tmp_store.go b/service/tmp_store.go deleted file mode 100644 index 9dea56a..0000000 --- a/service/tmp_store.go +++ /dev/null @@ -1,38 +0,0 @@ -package service - -import ( - "fmt" - "os" - "path/filepath" -) - -// saveTmpResult 将模型输出写入临时文件,用于 OSS 上传失败后的“仅重试 OSS”。 -func saveTmpResult(taskID string, data []byte, ext string) (string, error) { - dir := filepath.Join(os.TempDir(), "model-asynch") - if err := os.MkdirAll(dir, 0o755); err != nil { - return "", err - } - if ext == "" { - ext = ".bin" - } - if ext[0] != '.' { - ext = "." + ext - } - path := filepath.Join(dir, fmt.Sprintf("%s%s", taskID, ext)) - if err := os.WriteFile(path, data, 0o644); err != nil { - return "", err - } - return path, nil -} - -func loadTmpResult(path string) ([]byte, error) { - return os.ReadFile(path) -} - -func deleteTmpResult(path string) { - if path == "" { - return - } - _ = os.Remove(path) -} - diff --git a/service/utils.go b/service/utils.go deleted file mode 100644 index 72e4367..0000000 --- a/service/utils.go +++ /dev/null @@ -1,113 +0,0 @@ -package service - -import ( - "encoding/json" - "strings" - - "github.com/gogf/gf/v2/container/gvar" -) - -func normalizeFormValue(v any) any { - // 目标:对外永远返回 JSON 数组/对象,而不是字符串。 - if v == nil { - return []any{} - } - switch t := v.(type) { - case string: - s := strings.TrimSpace(t) - if s == "" { - return []any{} - } - return normalizeFormValueFromJSONString(s) - case []byte: - if len(t) == 0 { - return []any{} - } - return normalizeFormValueFromJSONBytes(t) - case *gvar.Var: - // goframe 常见的 DB 返回类型 - if t == nil { - return []any{} - } - b := t.Bytes() - if len(b) > 0 { - return normalizeFormValueFromJSONBytes(b) - } - s := strings.TrimSpace(t.String()) - if s == "" { - return []any{} - } - return normalizeFormValueFromJSONString(s) - default: - // 尝试兼容其他“像 JSON 的值类型”(例如实现了 Bytes/String 的包装类型) - if vb, ok := v.(interface{ Bytes() []byte }); ok { - if b := vb.Bytes(); len(b) > 0 { - return normalizeFormValueFromJSONBytes(b) - } - } - if vs, ok := v.(interface{ String() string }); ok { - if s := strings.TrimSpace(vs.String()); s != "" { - return normalizeFormValueFromJSONString(s) - } - } - // 已经是 []any / map[string]any 等结构 - return v - } -} - -// 兼容“JSONB 里存了 JSON 字符串”的历史数据: -// 例如 form_json = '"[]"' 或 '"[{...}]"'(外层是字符串,内层才是数组/对象) -func normalizeFormValueFromJSONString(s string) any { - var out any - if err := json.Unmarshal([]byte(s), &out); err != nil || out == nil { - return []any{} - } - // 如果解出来还是 string,且看起来是 JSON,再解一层 - if inner, ok := out.(string); ok { - inner = strings.TrimSpace(inner) - if inner == "" { - return []any{} - } - if strings.HasPrefix(inner, "[") || strings.HasPrefix(inner, "{") { - var out2 any - if err := json.Unmarshal([]byte(inner), &out2); err == nil && out2 != nil { - return out2 - } - } - return []any{} - } - return out -} - -func normalizeFormValueFromJSONBytes(b []byte) any { - var out any - if err := json.Unmarshal(b, &out); err != nil || out == nil { - return []any{} - } - // bytes 解出来也可能是 string(同上) - if inner, ok := out.(string); ok { - return normalizeFormValueFromJSONString(inner) - } - return out -} - -func ParseJSONField(field any) any { - var v *gvar.Var - switch val := field.(type) { - case *gvar.Var: - v = val - default: - return field - } - - if v == nil || v.IsNil() || v.IsEmpty() { - return nil - } - - str := v.String() - var result any - if json.Unmarshal([]byte(str), &result) == nil { - return result - } - return str -} diff --git a/service/worker.go b/service/worker.go index cff9aab..51779ba 100644 --- a/service/worker.go +++ b/service/worker.go @@ -2,7 +2,13 @@ package service import ( "context" + "errors" "fmt" + "model-gateway/common/util" + "model-gateway/model/dto" + "model-gateway/service/gateway" + "os" + "path/filepath" "strings" "time" "unicode/utf8" @@ -23,24 +29,23 @@ type asyncWorker struct { // RunOnce 由上层定时任务触发:一次性抢占并处理一批任务 // - batchSize: 本次抢占数量 // - goroutines: 本次并发数(协程池大小) -func (w *asyncWorker) RunOnce(ctx context.Context, batchSize, goroutines int) (claimed int, err error) { - if batchSize <= 0 { - batchSize = 10 +func (w *asyncWorker) RunOnce(ctx context.Context, req *dto.RunWorkReq) (res *dto.RunWorkRes, err error) { + if req.BatchSize <= 0 { + req.BatchSize = 10 } - if goroutines <= 0 { - goroutines = 1 + if req.Goroutines <= 0 { + req.Goroutines = 1 } - tasks, err := dao.Task.ClaimPendingGlobal(ctx, batchSize) + tasks, err := dao.Task.ClaimPendingGlobal(ctx, req.BatchSize) if err != nil { - return 0, err + return nil, err } if len(tasks) == 0 { - return 0, nil + return nil, errors.New("no task to run") } - pool := grpool.New(goroutines) + pool := grpool.New(req.Goroutines) defer pool.Close() - - claimed = len(tasks) + claimed := len(tasks) done := make(chan struct{}, claimed) for _, t := range tasks { task := t @@ -58,7 +63,9 @@ func (w *asyncWorker) RunOnce(ctx context.Context, batchSize, goroutines int) (c for i := 0; i < claimed; i++ { <-done } - return claimed, nil + return &dto.RunWorkRes{ + Claimed: claimed, + }, nil } // RunByTaskID 创建任务后立即异步尝试执行当前任务: @@ -78,9 +85,9 @@ func (w *asyncWorker) RunByTaskID(ctx context.Context, taskID string, epicycleId func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicycleId int64) { // 从任务入库的 request_payload 里恢复 payload + headers - payload, headers := parseStoredPayload(t.RequestPayload) + payload, headers := util.ParseStoredPayload(t.RequestPayload) if len(headers) > 0 { - ctx = setTaskHeadersToCtx(ctx, headers) + ctx = util.SetTaskHeadersToCtx(ctx, headers) } // 1) 拉取模型配置 @@ -91,7 +98,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy // ============ 失败回调 ============ t.State = 3 t.ErrorMsg = err.Error() - go triggerCallback(context.WithoutCancel(ctx), t) + go gateway.TriggerCallback(context.WithoutCancel(ctx), t) // ================================ return } @@ -102,7 +109,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy // ============ 失败回调 ============ t.State = 3 t.ErrorMsg = errMsg - go triggerCallback(context.WithoutCancel(ctx), t) + go gateway.TriggerCallback(context.WithoutCancel(ctx), t) // ================================ return } @@ -118,7 +125,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy // ============ 失败回调 ============ t.State = 3 t.ErrorMsg = err.Error() - go triggerCallback(context.WithoutCancel(ctx), t) + go gateway.TriggerCallback(context.WithoutCancel(ctx), t) // ================================ return } @@ -147,9 +154,9 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy // phase=1 表示模型已成功但 OSS 上传失败:优先从临时文件加载 if t.Phase == 1 && strings.TrimSpace(t.TmpFile) != "" { - data, err = loadTmpResult(t.TmpFile) + data, err = os.ReadFile(t.TmpFile) if err == nil && len(data) > 0 { - contentType, ext = DetectFileType(data) + contentType, ext = util.DetectFileType(data) } else { data = nil } @@ -165,11 +172,11 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy // ============ 失败回调 ============ t.State = 3 t.ErrorMsg = err.Error() - go triggerCallback(context.WithoutCancel(ctx), t) + go gateway.TriggerCallback(context.WithoutCancel(ctx), t) // ================================ return } - contentType, ext = DetectFileType(data) + contentType, ext = util.DetectFileType(data) if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") { textResult = string(data) } @@ -182,7 +189,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy } // 4) 存储 OSS - ossURL, err := Storage.UploadByTask(ctx, t, data, ext, contentType) + ossURL, err := gateway.UploadByTask(ctx, t, data, ext, contentType) if err != nil { // OSS 阶段失败:保留临时文件,下一轮仅重试 OSS _ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, t.Id, err.Error()) @@ -198,7 +205,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy if fileType == "" { fileType = contentType } - if err := dao.Task.UpdateSuccessGlobal( + if err = dao.Task.UpdateSuccessGlobal( ctx, t.Id, ossURL, @@ -206,7 +213,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy textResult, int64(len(data)), nil, - GetExpendTokens(m.TokenMapping, textResult), + GetExpendTokens(m.ResponseTokenField, textResult), ); err != nil { g.Log().Errorf(ctx, "[worker] update success failed: %v", err) return @@ -221,14 +228,33 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy t.FileType = fileType t.TextResult = textResult g.Log().Infof(ctx, "[CALLBACK][DISPATCH] taskId=%s bizName=%s callbackUrl=%s", t.TaskID, t.BizName, t.CallbackURL) - go triggerCallback(context.WithoutCancel(ctx), t) + go gateway.TriggerCallback(context.WithoutCancel(ctx), t) // ============ 如果有 epicycleId,也触发业务回调 ============ if epicycleId != 0 { - go triggerPromptsCallback(context.WithoutCancel(ctx), t, epicycleId) + go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), t, epicycleId) } // 成功后清理临时文件 - deleteTmpResult(t.TmpFile) + _ = os.Remove(t.TmpFile) +} + +// saveTmpResult 将模型输出写入临时文件,用于 OSS 上传失败后的“仅重试 OSS”。 +func saveTmpResult(taskID string, data []byte, ext string) (string, error) { + dir := filepath.Join(os.TempDir(), "model-asynch") + if err := os.MkdirAll(dir, 0o755); err != nil { + return "", err + } + if ext == "" { + ext = ".bin" + } + if ext[0] != '.' { + ext = "." + ext + } + path := filepath.Join(dir, fmt.Sprintf("%s%s", taskID, ext)) + if err := os.WriteFile(path, data, 0o644); err != nil { + return "", err + } + return path, nil } func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error { @@ -240,7 +266,6 @@ func GetExpendTokens(tokenMapping string, textResult string) int { value := gjson.Get(textResult, tokenMapping) if value.Exists() { return int(value.Int()) - } else { - return len(textResult) } + return len(textResult) } diff --git a/timezone/Shanghai b/timezone/Shanghai deleted file mode 100644 index 91f6f8b..0000000 Binary files a/timezone/Shanghai and /dev/null differ diff --git a/timezone/localtime b/timezone/localtime deleted file mode 100644 index 91f6f8b..0000000 Binary files a/timezone/localtime and /dev/null differ diff --git a/update.sql b/update.sql index 4dc2578..c916a95 100644 --- a/update.sql +++ b/update.sql @@ -40,9 +40,18 @@ CREATE TABLE IF NOT EXISTS asynch_models ( retry_queue_max_seconds INT NOT NULL DEFAULT 600, -- 失败重试最大排队时间(秒 0=插队到队首;>0=排队超过该时间后插队,否则仍到队尾) auto_clean_seconds INT NOT NULL DEFAULT 86400, -- 已下载(state=4 后的保留时间(秒),到期清理) remark TEXT DEFAULT '' -- 备注 - token_mapping VARCHAR(128) NOT NULL DEFAULT ''; -- token 映射 -); - + response_token_field VARCHAR(128) NOT NULL DEFAULT ''; -- 响应中消耗token的字段映射 + operator_name VARCHAR(64) NOT NULL DEFAULT '', -- 运营商名称 + token_config JSONB NOT NULL DEFAULT '{ + "zh_ratio": 1.0, + "en_ratio": 1.3, + "space_ratio": 0.1, + "punctuation_ratio": 0.1, + "max_window_size": 8192, + "reserve_ratio": 0.2, + "min_reserve": 512, +}'::jsonb -- Token配置 + ); CREATE UNIQUE INDEX IF NOT EXISTS uk_asynch_models_tenant_creator_chat ON asynch_models(tenant_id, creator) WHERE is_chat_model = 1 AND deleted_at IS NULL; CREATE UNIQUE INDEX IF NOT EXISTS uk_asynch_models_tenant_model_name ON asynch_models(tenant_id, creator, model_name); CREATE INDEX IF NOT EXISTS idx_asynch_models_tenant_id ON asynch_models(tenant_id); @@ -83,8 +92,17 @@ COMMENT ON COLUMN asynch_models.retry_times IS '失败重试次数'; COMMENT ON COLUMN asynch_models.retry_queue_max_seconds IS '失败重试最大排队时间(秒 0=插队到队首;>0=排队超过该时间后插队,否则仍到队尾)'; COMMENT ON COLUMN asynch_models.auto_clean_seconds IS '已下载(state=4 后的保留时间(秒),到期清理)'; COMMENT ON COLUMN asynch_models.remark IS '备注'; -COMMENT ON COLUMN asynch_models.token_mapping IS 'token映射'; - +COMMENT ON COLUMN asynch_models.response_token_field IS '响应中消耗token的字段映射'; +COMMENT ON COLUMN asynch_models.operator_name IS '运营商名称'; +COMMENT ON COLUMN asynch_models.token_config IS '{ + "zh_ratio": 1.0, // 中文字符→token系数 + "en_ratio": 1.3, // 英文单词→token系数 + "space_ratio": 0.1, // 空格系数 + "punctuation_ratio": 0.1, // 标点系数 + "max_window_size": 8192, // 模型最大窗口 + "reserve_ratio": 0.2, // 预留回复空间比例 + "min_reserve": 512, // 最少预留token数 +}'; -- =========================