refactor(model): 重构模型实体和数据访问层
This commit is contained in:
69
common/util/files.go
Normal file
69
common/util/files.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// DetectFileType 根据返回的二进制内容推断 contentType + 扩展名(尽量稳定)
|
||||
func DetectFileType(data []byte) (contentType string, ext string) {
|
||||
if len(data) == 0 {
|
||||
return "application/octet-stream", ""
|
||||
}
|
||||
ct := http.DetectContentType(data)
|
||||
// gateway.DetectContentType 可能带 charset 等参数:text/plain; charset=utf-8
|
||||
if idx := strings.Index(ct, ";"); idx > 0 {
|
||||
ct = strings.TrimSpace(ct[:idx])
|
||||
}
|
||||
switch ct {
|
||||
case "audio/mpeg":
|
||||
return ct, ".mp3"
|
||||
case "audio/wave", "audio/wav", "audio/x-wav":
|
||||
return ct, ".wav"
|
||||
case "video/mp4":
|
||||
return ct, ".mp4"
|
||||
case "image/png":
|
||||
return ct, ".png"
|
||||
case "image/jpeg":
|
||||
return ct, ".jpg"
|
||||
case "application/pdf":
|
||||
return ct, ".pdf"
|
||||
case "text/plain":
|
||||
return ct, ".txt"
|
||||
case "application/json":
|
||||
return ct, ".json"
|
||||
default:
|
||||
// 兜底:尝试从 ct 截取 subtype 作为后缀(例如 application/json)
|
||||
if parts := strings.Split(ct, "/"); len(parts) == 2 {
|
||||
sub := parts[1]
|
||||
// 避免出现 "plain; charset=utf-8" 之类的后缀
|
||||
if idx := strings.Index(sub, ";"); idx > 0 {
|
||||
sub = strings.TrimSpace(sub[:idx])
|
||||
}
|
||||
return ct, "." + sub
|
||||
}
|
||||
return ct, ""
|
||||
}
|
||||
}
|
||||
|
||||
// SaveTmpResult 将模型输出写入临时文件,用于 OSS 上传失败后的“仅重试 OSS”。
|
||||
func SaveTmpResult(taskID string, data []byte, ext string) (string, error) {
|
||||
dir := filepath.Join(os.TempDir(), "model-asynch")
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if ext == "" {
|
||||
ext = ".bin"
|
||||
}
|
||||
if ext[0] != '.' {
|
||||
ext = "." + ext
|
||||
}
|
||||
path := filepath.Join(dir, fmt.Sprintf("%s%s", taskID, ext))
|
||||
if err := os.WriteFile(path, data, 0o644); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
100
common/util/headers.go
Normal file
100
common/util/headers.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// AsyncCtx 固化异步上下文中的 token 和用户信息,避免请求结束后丢失
|
||||
func AsyncCtx(ctx context.Context) context.Context {
|
||||
asyncCtx := context.WithoutCancel(ctx)
|
||||
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
if token := r.Header.Get("Authorization"); token != "" {
|
||||
asyncCtx = context.WithValue(asyncCtx, "token", token)
|
||||
}
|
||||
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
|
||||
asyncCtx = context.WithValue(asyncCtx, "xUserInfo", userInfo)
|
||||
}
|
||||
}
|
||||
|
||||
if user, err := utils.GetUserInfo(ctx); err == nil && user != nil {
|
||||
asyncCtx = context.WithValue(asyncCtx, "user", user)
|
||||
}
|
||||
|
||||
return asyncCtx
|
||||
}
|
||||
|
||||
// ForwardHeaders 透传调用链路的头信息,优先使用 ctx 中的固化值
|
||||
func ForwardHeaders(ctx context.Context) map[string]string {
|
||||
headers := make(map[string]string)
|
||||
SetHeaderFromContext(headers, ctx, "Authorization", "token")
|
||||
SetHeaderFromContext(headers, ctx, "X-User-Info", "xUserInfo")
|
||||
FallbackToRequestHeaders(headers, ctx)
|
||||
return headers
|
||||
}
|
||||
|
||||
// SetHeaderFromContext 从上下文中设置 header
|
||||
func SetHeaderFromContext(headers map[string]string, ctx context.Context, headerKey, ctxKey string) {
|
||||
if value, ok := ctx.Value(ctxKey).(string); ok && value != "" {
|
||||
headers[headerKey] = value
|
||||
}
|
||||
}
|
||||
|
||||
// FallbackToRequestHeaders 从请求头中获取作为兜底
|
||||
func FallbackToRequestHeaders(headers map[string]string, ctx context.Context) {
|
||||
r := g.RequestFromCtx(ctx)
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if headers["Authorization"] == "" {
|
||||
if token := r.Header.Get("Authorization"); token != "" {
|
||||
headers["Authorization"] = token
|
||||
}
|
||||
}
|
||||
|
||||
if headers["X-User-Info"] == "" {
|
||||
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
|
||||
headers["X-User-Info"] = userInfo
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetTaskHeadersToCtx 把任务入库时保存的 header 信息注入 ctx,给 worker 调 OSS 用
|
||||
func SetTaskHeadersToCtx(ctx context.Context, headers map[string]string) context.Context {
|
||||
if headers == nil {
|
||||
return ctx
|
||||
}
|
||||
if v := gconv.String(headers["Authorization"]); v != "" {
|
||||
ctx = context.WithValue(ctx, "token", v)
|
||||
}
|
||||
if v := gconv.String(headers["X-User-Info"]); v != "" {
|
||||
ctx = context.WithValue(ctx, "xUserInfo", v)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
// ParseStoredPayload 解析入库的 request_payload,拆出模型调用 payload 与透传 headers
|
||||
// 入库格式:{"payload": <any>, "headers": {"Authorization": "...", "X-User-Info":"..."}}
|
||||
func ParseStoredPayload(v any) (payload any, headers map[string]string) {
|
||||
if v == nil {
|
||||
return nil, nil
|
||||
}
|
||||
m := gconv.Map(v)
|
||||
if len(m) == 0 {
|
||||
return v, nil
|
||||
}
|
||||
if h, ok := m["headers"]; ok {
|
||||
headers = gconv.MapStrStr(h)
|
||||
}
|
||||
if p, ok := m["payload"]; ok {
|
||||
payload = p
|
||||
} else {
|
||||
payload = v
|
||||
}
|
||||
return
|
||||
}
|
||||
28
common/util/json.go
Normal file
28
common/util/json.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/gogf/gf/v2/container/gvar"
|
||||
)
|
||||
|
||||
func ParseJSONField(field any) any {
|
||||
var v *gvar.Var
|
||||
switch val := field.(type) {
|
||||
case *gvar.Var:
|
||||
v = val
|
||||
default:
|
||||
return field
|
||||
}
|
||||
|
||||
if v == nil || v.IsNil() || v.IsEmpty() {
|
||||
return nil
|
||||
}
|
||||
|
||||
str := v.String()
|
||||
var result any
|
||||
if json.Unmarshal([]byte(str), &result) == nil {
|
||||
return result
|
||||
}
|
||||
return str
|
||||
}
|
||||
29
config.yml
29
config.yml
@@ -26,6 +26,27 @@ database:
|
||||
updatedAt: "updated_at" # (可选)自动更新时间字段名称
|
||||
deletedAt: "deleted_at" # (可选)软删除时间字段名称
|
||||
timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性,为true时CreatedAt/UpdatedAt/DeletedAt都将失效
|
||||
model_gateway:
|
||||
- type: "pgsql"
|
||||
host: "116.204.74.41"
|
||||
port: "15432"
|
||||
user: "postgres"
|
||||
pass: "Bjang09@686^*^"
|
||||
name: "model-gateway"
|
||||
prefix: ""
|
||||
role: "master"
|
||||
debug: true
|
||||
dryRun: false
|
||||
charset: "utf8"
|
||||
timezone: "Asia/Shanghai"
|
||||
maxIdle: 5
|
||||
maxOpen: 20
|
||||
maxLifetime: "30s"
|
||||
maxIdleConnTime: "30s"
|
||||
createdAt: "created_at"
|
||||
updatedAt: "updated_at"
|
||||
deletedAt: "deleted_at"
|
||||
timeMaintainDisabled: false
|
||||
|
||||
redis:
|
||||
default:
|
||||
@@ -48,11 +69,3 @@ asynch:
|
||||
cleaner:
|
||||
enabled: false
|
||||
intervalSeconds: 30
|
||||
|
||||
modelType:
|
||||
types:
|
||||
1: "推理模型"
|
||||
2: "图片模型"
|
||||
3: "音频模型"
|
||||
4: "向量化模型"
|
||||
5: "全模态模型"
|
||||
|
||||
19
consts/public/public.go
Normal file
19
consts/public/public.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package public
|
||||
|
||||
// ModelType 模型类型常量
|
||||
const (
|
||||
ModelTypeInference = 1 // 推理模型
|
||||
ModelTypeImage = 2 // 图片模型
|
||||
ModelTypeAudio = 3 // 音频模型
|
||||
ModelTypeVector = 4 // 向量化模型
|
||||
ModelTypeOmni = 5 // 全模态模型
|
||||
)
|
||||
|
||||
// ModelTypeName 模型类型名称映射
|
||||
var ModelTypeName = map[int]string{
|
||||
ModelTypeInference: "推理模型",
|
||||
ModelTypeImage: "图片模型",
|
||||
ModelTypeAudio: "音频模型",
|
||||
ModelTypeVector: "向量化模型",
|
||||
ModelTypeOmni: "全模态模型",
|
||||
}
|
||||
@@ -1,5 +1,9 @@
|
||||
package public
|
||||
|
||||
const (
|
||||
DbNameModelGateway = "model_gateway" //数据库名称
|
||||
)
|
||||
|
||||
const (
|
||||
TableNameModel = "asynch_models" // 模型表
|
||||
TableNameTask = "asynch_task" // 任务表
|
||||
|
||||
@@ -4,10 +4,7 @@ import (
|
||||
"context"
|
||||
|
||||
"model-gateway/model/dto"
|
||||
"model-gateway/model/entity"
|
||||
"model-gateway/service"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
)
|
||||
|
||||
type model struct{}
|
||||
@@ -21,67 +18,44 @@ func (c *model) CreateModel(ctx context.Context, req *dto.CreateModelReq) (res *
|
||||
}
|
||||
|
||||
// UpdateModel 更改配置
|
||||
func (c *model) UpdateModel(ctx context.Context, req *dto.UpdateModelReq) (res *beans.ResponseEmpty, err error) {
|
||||
func (c *model) UpdateModel(ctx context.Context, req *dto.UpdateModelReq) (res *dto.UpdateModelRes, err error) {
|
||||
err = service.Model.Update(ctx, req)
|
||||
return
|
||||
}
|
||||
|
||||
// DeleteModel 删除配置
|
||||
func (c *model) DeleteModel(ctx context.Context, req *dto.DeleteModelReq) (res *beans.ResponseEmpty, err error) {
|
||||
err = service.Model.Delete(ctx, req.ID)
|
||||
func (c *model) DeleteModel(ctx context.Context, req *dto.DeleteModelReq) (res *dto.DeleteModelRes, err error) {
|
||||
err = service.Model.Delete(ctx, req)
|
||||
return
|
||||
}
|
||||
|
||||
// GetModel 获取配置详情(按 modelName)
|
||||
// GetModel 获取配置详情
|
||||
func (c *model) GetModel(ctx context.Context, req *dto.GetModelReq) (res *dto.GetModelRes, err error) {
|
||||
model, err := service.Model.Get(ctx, req.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if model == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return &dto.GetModelRes{Model: model}, nil
|
||||
return service.Model.Get(ctx, req)
|
||||
}
|
||||
|
||||
// ListModel 配置列表
|
||||
func (c *model) ListModel(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) {
|
||||
list, total, err := service.Model.List(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.ListModelRes{
|
||||
List: list,
|
||||
Total: total,
|
||||
}, nil
|
||||
return service.Model.List(ctx, req)
|
||||
}
|
||||
|
||||
// AutoTune 动态调参(由上层定时任务每小时触发一次)
|
||||
func (c *model) AutoTune(ctx context.Context, req *dto.AutoTuneReq) (res *dto.AutoTuneRes, err error) {
|
||||
windowSeconds := 3600
|
||||
if req != nil && req.WindowSeconds > 0 {
|
||||
windowSeconds = req.WindowSeconds
|
||||
}
|
||||
list, err := service.AutoTune(ctx, windowSeconds)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.AutoTuneRes{List: list}, nil
|
||||
return service.AutoTune(ctx, req)
|
||||
}
|
||||
|
||||
func (c *model) ListType(ctx context.Context, req *dto.ListTypeReq) (res dto.TypeItem, err error) {
|
||||
modelType := service.GetModelTypesFromConfig(ctx)
|
||||
res.Type = modelType
|
||||
return res, nil
|
||||
// ListType 模型类型列表
|
||||
func (c *model) ListType(ctx context.Context, req *dto.ListTypeReq) (res *dto.TypeItem, err error) {
|
||||
return service.GetModelTypesFromConfig()
|
||||
}
|
||||
|
||||
// UpdateChatModel 更新是否为聊天模型
|
||||
func (c *model) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) (res *beans.ResponseEmpty, err error) {
|
||||
func (c *model) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) (res *dto.UpdateChatModelRes, err error) {
|
||||
err = service.Model.UpdateChatModel(ctx, req)
|
||||
return
|
||||
}
|
||||
|
||||
// GetIsChatModel 获取是否为聊天模型
|
||||
func (c *model) GetIsChatModel(ctx context.Context, req *dto.GetIsChatModelReq) (res *entity.AsynchModel, err error) {
|
||||
// GetIsChatModel 获取当前会话模型
|
||||
func (c *model) GetIsChatModel(ctx context.Context, req *dto.GetIsChatModelReq) (res *dto.GetIsChatModelRes, err error) {
|
||||
return service.Model.GetIsChatModel(ctx)
|
||||
}
|
||||
|
||||
@@ -34,24 +34,10 @@ func (c *task) ListTask(ctx context.Context, req *dto.ListTaskReq) (res *dto.Lis
|
||||
|
||||
// RunWork 手动触发一次 worker(由上层定时任务调用)
|
||||
func (c *task) RunWork(ctx context.Context, req *dto.RunWorkReq) (res *dto.RunWorkRes, err error) {
|
||||
batchSize, goroutines := 10, 1
|
||||
if req != nil {
|
||||
if req.BatchSize > 0 {
|
||||
batchSize = req.BatchSize
|
||||
}
|
||||
if req.Goroutines > 0 {
|
||||
goroutines = req.Goroutines
|
||||
}
|
||||
}
|
||||
n, err := service.AsyncWorker.RunOnce(ctx, batchSize, goroutines)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.RunWorkRes{Claimed: n}, nil
|
||||
return service.AsyncWorker.RunOnce(ctx, req)
|
||||
}
|
||||
|
||||
// CleanWork 手动触发一次 cleaner(由上层定时任务调用)
|
||||
func (c *task) CleanWork(ctx context.Context, req *dto.CleanWorkReq) (res *dto.CleanWorkRes, err error) {
|
||||
service.Cleaner.RunOnce(ctx)
|
||||
return &dto.CleanWorkRes{Ok: true}, nil
|
||||
return service.Cleaner.RunOnce(ctx)
|
||||
}
|
||||
|
||||
190
dao/model_dao.go
190
dao/model_dao.go
@@ -2,14 +2,11 @@ package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"model-gateway/consts/public"
|
||||
"model-gateway/model/dto"
|
||||
"model-gateway/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
@@ -18,157 +15,80 @@ var Model = &modelDao{}
|
||||
|
||||
type modelDao struct{}
|
||||
|
||||
func (d *modelDao) Insert(ctx context.Context, req *dto.CreateModelReq) (id int64, err error) {
|
||||
asyncModel := new(entity.AsynchModel)
|
||||
err = gconv.Struct(req, &asyncModel)
|
||||
// Insert 插入
|
||||
func (d *modelDao) Insert(ctx context.Context, req *entity.AsynchModel) (id int64, err error) {
|
||||
m := new(entity.AsynchModel)
|
||||
err = gconv.Struct(req, &m)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).Data(asyncModel).Insert()
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||
Insert(m)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return
|
||||
}
|
||||
return r.LastInsertId()
|
||||
}
|
||||
|
||||
func (d *modelDao) Update(ctx context.Context, m *dto.UpdateModelReq) (rows int64, err error) {
|
||||
// 触发 gfdb 的 updateHook 自动填充 updater,需要显式带 updater 字段
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
||||
// Update 更新
|
||||
func (d *modelDao) Update(ctx context.Context, req *entity.AsynchModel) (rows int64, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||
OmitEmpty().
|
||||
Where(entity.AsynchModelCol.Id, m.ID).
|
||||
Data(m).
|
||||
Data(&req).
|
||||
Where(entity.AsynchModelCol.Id, req.Id).
|
||||
Update()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return
|
||||
}
|
||||
return r.RowsAffected()
|
||||
}
|
||||
|
||||
func (d *modelDao) DeleteByID(ctx context.Context, id string) (rows int64, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
||||
Where(entity.AsynchModelCol.Id, id).
|
||||
// Delete 删除
|
||||
func (d *modelDao) Delete(ctx context.Context, req *entity.AsynchModel) (rows int64, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||
OmitEmpty().
|
||||
Where(entity.AsynchModelCol.Id, req.Id).
|
||||
Delete()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return
|
||||
}
|
||||
return r.RowsAffected()
|
||||
}
|
||||
|
||||
func (d *modelDao) GetByModelName(ctx context.Context, modelName string) (m *entity.AsynchModel, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
||||
Where(entity.AsynchModelCol.ModelName, modelName).
|
||||
One()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
err = r.Struct(&m)
|
||||
return
|
||||
}
|
||||
|
||||
func (d *modelDao) Get(ctx context.Context, id int64) (m *entity.AsynchModel, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
||||
NoTenantId(ctx).
|
||||
Where(entity.AsynchModelCol.Id, id).
|
||||
One()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
err = r.Struct(&m)
|
||||
return
|
||||
}
|
||||
|
||||
func (d *modelDao) Count(ctx context.Context, req *dto.GetModelReq) (count int, err error) {
|
||||
count, err = gfdb.DB(ctx).Model(ctx, public.TableNameModel).OmitEmpty().
|
||||
// Get 按ID获取(带租户隔离,只查当前租户)
|
||||
func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||
OmitEmpty().
|
||||
Where(entity.AsynchModelCol.Id, req.Id).
|
||||
Where(entity.AsynchModelCol.Creator, req.Creator).
|
||||
Where(entity.AsynchModelCol.Id, req.ID).Count()
|
||||
return
|
||||
}
|
||||
|
||||
func (d *modelDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike string, modelType int, isPrivate int) (list []*entity.AsynchModel, total int64, err error) {
|
||||
model := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
||||
OrderDesc(entity.AsynchModelCol.CreatedAt)
|
||||
if modelNameLike != "" {
|
||||
model = model.WhereLike(entity.AsynchModelCol.ModelName, "%"+modelNameLike+"%")
|
||||
}
|
||||
if modelType != 0 {
|
||||
model = model.Where(entity.AsynchModelCol.ModelType, modelType)
|
||||
}
|
||||
if isPrivate != 0 {
|
||||
model = model.Where(entity.AsynchModelCol.IsPrivate, isPrivate)
|
||||
}
|
||||
if pageNum > 0 && pageSize > 0 {
|
||||
model = model.Page(pageNum, pageSize)
|
||||
}
|
||||
r, totalInt, err := model.AllAndCount(false)
|
||||
Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel).
|
||||
Where(entity.AsynchModelCol.ModelName, req.ModelName).
|
||||
Fields(fields).One()
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
total = gconv.Int64(totalInt)
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
}
|
||||
|
||||
func (d *modelDao) GetByIsChatModel(ctx context.Context) (m *entity.AsynchModel, err error) {
|
||||
userInfo, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
||||
Where(entity.AsynchModelCol.IsChatModel, 1).
|
||||
Where(entity.AsynchModelCol.Creator, userInfo.UserName).
|
||||
One()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
return nil, nil
|
||||
return
|
||||
}
|
||||
err = r.Struct(&m)
|
||||
return
|
||||
}
|
||||
|
||||
// ListByCreatorAndPlatform 普通用户:平台公共(tenant_id=0) + 自己创建的(creator=xxx)
|
||||
func (d *modelDao) ListByCreatorAndPlatform(ctx context.Context, creator string, pageNum, pageSize int, modelNameLike string) (list []*entity.AsynchModel, total int64, err error) {
|
||||
// 构建 Where 条件
|
||||
whereSQL := "deleted_at IS NULL AND (tenant_id = 1 OR creator = ?)" //1 代表超级管理员
|
||||
args := []any{creator}
|
||||
|
||||
if modelNameLike != "" {
|
||||
whereSQL += " AND model_name LIKE ?"
|
||||
args = append(args, "%"+modelNameLike+"%")
|
||||
}
|
||||
|
||||
// 查总数
|
||||
countSQL := fmt.Sprintf("SELECT COUNT(1) FROM %s WHERE %s", public.TableNameModel, whereSQL)
|
||||
countResult, err := gfdb.DB(ctx).GetAll(ctx, countSQL, args...)
|
||||
// GetByAcrossTenant 按ID获取(跨租户,查所有租户)
|
||||
func (d *modelDao) GetByAcrossTenant(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||
NoTenantId(ctx).
|
||||
OmitEmpty().
|
||||
Where(entity.AsynchModelCol.Id, req.Id).
|
||||
Where(entity.AsynchModelCol.Creator, req.Creator).
|
||||
Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel).
|
||||
Where(entity.AsynchModelCol.ModelName, req.ModelName).
|
||||
Fields(fields).One()
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
return
|
||||
}
|
||||
if len(countResult) > 0 {
|
||||
total = gconv.Int64(countResult[0]["count"])
|
||||
}
|
||||
|
||||
// 查列表
|
||||
querySQL := fmt.Sprintf("SELECT * FROM %s WHERE %s ORDER BY created_at DESC", public.TableNameModel, whereSQL)
|
||||
if pageNum > 0 && pageSize > 0 {
|
||||
offset := (pageNum - 1) * pageSize
|
||||
querySQL += fmt.Sprintf(" LIMIT %d OFFSET %d", pageSize, offset)
|
||||
}
|
||||
|
||||
r, err := gfdb.DB(ctx).GetAll(ctx, querySQL, args...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
err = r.Structs(&list)
|
||||
err = r.Struct(&m)
|
||||
return
|
||||
}
|
||||
|
||||
// GetByCreatorAndPlatform 按创建者、平台获取
|
||||
func (d *modelDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.AsynchModel, total int, err error) {
|
||||
// 基础 SQL
|
||||
sql := `
|
||||
@@ -212,7 +132,7 @@ WHERE deleted_at IS NULL
|
||||
// 最后拼接排序
|
||||
sql += ` ORDER BY model_name, is_owner DESC, created_at DESC`
|
||||
|
||||
r, err := gfdb.DB(ctx).GetAll(ctx, sql, args...)
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
@@ -226,14 +146,24 @@ WHERE deleted_at IS NULL
|
||||
return
|
||||
}
|
||||
|
||||
// ListAll 用于分组展示:查询全部模型(不按类型过滤,类型拆分在 service 层处理)
|
||||
func (d *modelDao) ListAll(ctx context.Context) (list []*entity.AsynchModel, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
||||
OrderDesc(entity.AsynchModelCol.CreatedAt).
|
||||
All()
|
||||
// GetByModelNameForTenant 后台任务使用:按 tenant_id + model_name 查询,不依赖 gfdb Hook/Trace/用户上下文
|
||||
func (d *modelDao) GetByModelNameForTenant(ctx context.Context, tenantId uint64, modelName string) (m *entity.AsynchModel, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx,
|
||||
"SELECT * FROM "+public.TableNameModel+" WHERE tenant_id=? AND model_name=? AND deleted_at IS NULL LIMIT 1",
|
||||
tenantId, modelName,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
if r.IsEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
var list []*entity.AsynchModel
|
||||
if err := r.Structs(&list); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"model-gateway/consts/public"
|
||||
"model-gateway/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
)
|
||||
|
||||
// GetByModelNameForTenant 后台任务使用:按 tenant_id + model_name 查询,不依赖 gfdb Hook/Trace/用户上下文
|
||||
func (d *modelDao) GetByModelNameForTenant(ctx context.Context, tenantId uint64, modelName string) (m *entity.AsynchModel, err error) {
|
||||
r, err := gfdb.DB(ctx).GetAll(ctx,
|
||||
"SELECT * FROM "+public.TableNameModel+" WHERE tenant_id=? AND model_name=? AND deleted_at IS NULL LIMIT 1",
|
||||
tenantId, modelName,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
var list []*entity.AsynchModel
|
||||
if err := r.Structs(&list); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(list) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return list[0], nil
|
||||
}
|
||||
@@ -7,14 +7,22 @@ import (
|
||||
"model-gateway/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
type opLogDao struct{}
|
||||
|
||||
var OpLog = &opLogDao{}
|
||||
|
||||
func (d *opLogDao) Insert(ctx context.Context, log *entity.LogsModelOp) (id int64, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameOpLog).Data(log).Insert()
|
||||
// Insert 插入
|
||||
func (d *opLogDao) Insert(ctx context.Context, req *entity.LogsModelOp) (id int64, err error) {
|
||||
m := new(entity.LogsModelOp)
|
||||
err = gconv.Struct(req, &m)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameOpLog).
|
||||
Insert(m)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ ON CONFLICT (day, tenant_id, creator, model_name)
|
||||
DO UPDATE SET request_count = %s.request_count + 1, updated_at = NOW()`,
|
||||
public.TableNameStat, public.TableNameStat,
|
||||
)
|
||||
_, err := gfdb.DB(ctx).Exec(ctx, sql, gtime.New(day).Format("Y-m-d"), tenantId, creator, modelName)
|
||||
_, err := gfdb.DB(ctx, public.DbNameModelGateway).Exec(ctx, sql, gtime.New(day).Format("Y-m-d"), tenantId, creator, modelName)
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
196
dao/task_dao.go
196
dao/task_dao.go
@@ -2,9 +2,6 @@ package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"model-gateway/consts/public"
|
||||
"model-gateway/model/entity"
|
||||
|
||||
@@ -18,40 +15,47 @@ var Task = &taskDao{}
|
||||
|
||||
type taskDao struct{}
|
||||
|
||||
func (d *taskDao) Insert(ctx context.Context, t *entity.AsynchTask) (id int64, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).Data(t).Insert()
|
||||
// Insert 插入
|
||||
func (d *taskDao) Insert(ctx context.Context, req *entity.AsynchTask) (id int64, err error) {
|
||||
m := new(entity.AsynchTask)
|
||||
err = gconv.Struct(req, &m)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return
|
||||
}
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||||
Insert(m)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return r.LastInsertId()
|
||||
}
|
||||
|
||||
func (d *taskDao) GetByTaskID(ctx context.Context, taskID string) (t *entity.AsynchTask, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||||
Where(entity.AsynchTaskCol.TaskID, taskID).
|
||||
One()
|
||||
// Get 获取
|
||||
func (d *taskDao) Get(ctx context.Context, req *entity.AsynchTask, fields ...string) (m *entity.AsynchTask, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
|
||||
OmitEmpty().
|
||||
Where(entity.AsynchTaskCol.TaskID, req.TaskID).
|
||||
Fields(fields).One()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
err = r.Struct(&t)
|
||||
err = r.Struct(&m)
|
||||
return
|
||||
}
|
||||
|
||||
// ListByTaskIDs 批量查询任务(会受 gfdb 的租户 Hook 影响,只返回当前租户数据)
|
||||
func (d *taskDao) ListByTaskIDs(ctx context.Context, taskIDs []string) (list []*entity.AsynchTask, err error) {
|
||||
func (d *taskDao) ListByTaskIDs(ctx context.Context, taskIDs []string) (m []*entity.AsynchTask, err error) {
|
||||
if len(taskIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
|
||||
OmitEmpty().
|
||||
WhereIn(entity.AsynchTaskCol.TaskID, taskIDs).
|
||||
All()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = r.Structs(&list)
|
||||
err = r.Structs(&m)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -62,7 +66,7 @@ func (d *taskDao) MarkDownloadedByID(ctx context.Context, id int64, expireAt *gt
|
||||
entity.AsynchTaskCol.ExpireAt: expireAt,
|
||||
entity.AsynchTaskCol.Updater: "",
|
||||
}
|
||||
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||||
_, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
|
||||
Where(entity.AsynchTaskCol.Id, id).
|
||||
Where(entity.AsynchTaskCol.State, 2).
|
||||
Data(data).
|
||||
@@ -70,73 +74,6 @@ func (d *taskDao) MarkDownloadedByID(ctx context.Context, id int64, expireAt *gt
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *taskDao) UpdateRunning(ctx context.Context, id int64) error {
|
||||
now := gtime.Now()
|
||||
data := gdb.Map{
|
||||
entity.AsynchTaskCol.State: 1,
|
||||
entity.AsynchTaskCol.StartedAt: now,
|
||||
entity.AsynchTaskCol.Updater: "",
|
||||
}
|
||||
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||||
Where(entity.AsynchTaskCol.Id, id).
|
||||
Data(data).
|
||||
Update()
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *taskDao) UpdateSuccess(ctx context.Context, id int64, ossFile, fileType string, fileSize int64, expireAt *gtime.Time) error {
|
||||
now := gtime.Now()
|
||||
data := gdb.Map{
|
||||
entity.AsynchTaskCol.State: 2,
|
||||
entity.AsynchTaskCol.OssFile: ossFile,
|
||||
entity.AsynchTaskCol.FileType: fileType,
|
||||
entity.AsynchTaskCol.FileSize: fileSize,
|
||||
entity.AsynchTaskCol.ErrorMsg: "",
|
||||
entity.AsynchTaskCol.FinishedAt: now,
|
||||
entity.AsynchTaskCol.ExpireAt: expireAt,
|
||||
entity.AsynchTaskCol.Updater: "",
|
||||
}
|
||||
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||||
Where(entity.AsynchTaskCol.Id, id).
|
||||
Data(data).
|
||||
Update()
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *taskDao) UpdateFailed(ctx context.Context, id int64, errorMsg string) error {
|
||||
now := gtime.Now()
|
||||
data := gdb.Map{
|
||||
entity.AsynchTaskCol.State: 3,
|
||||
entity.AsynchTaskCol.ErrorMsg: errorMsg,
|
||||
entity.AsynchTaskCol.FinishedAt: now,
|
||||
entity.AsynchTaskCol.Updater: "",
|
||||
}
|
||||
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||||
Where(entity.AsynchTaskCol.Id, id).
|
||||
Data(data).
|
||||
Update()
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *taskDao) SoftDeleteByTaskID(ctx context.Context, taskID string) (rows int64, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||||
Where(entity.AsynchTaskCol.TaskID, taskID).
|
||||
Delete()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return r.RowsAffected()
|
||||
}
|
||||
|
||||
// CountActiveByModel 统计某模型排队中/执行中的任务数,用于 queue_limit 限制(近似值)
|
||||
func (d *taskDao) CountActiveByModel(ctx context.Context, modelName string) (int64, error) {
|
||||
n, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||||
Where(entity.AsynchTaskCol.ModelName, modelName).
|
||||
WhereIn(entity.AsynchTaskCol.State, []int{0, 1}).
|
||||
Count()
|
||||
return int64(n), err
|
||||
}
|
||||
|
||||
// List 任务分页查询(受 gfdb 租户 Hook 影响)
|
||||
func (d *taskDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike, taskIDLike string, state *int) (list []*entity.AsynchTask, total int64, err error) {
|
||||
m := gfdb.DB(ctx).Model(ctx, public.TableNameTask).Where("deleted_at IS NULL")
|
||||
@@ -161,90 +98,3 @@ func (d *taskDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
}
|
||||
|
||||
// ClaimPending 抢占 pending 任务(state=0),并在同一事务中更新为 running(state=1)
|
||||
// 使用 PostgreSQL: FOR UPDATE SKIP LOCKED 避免多 worker 重复消费
|
||||
func (d *taskDao) ClaimPending(ctx context.Context, batchSize int) (tasks []*entity.AsynchTask, err error) {
|
||||
if batchSize <= 0 {
|
||||
batchSize = 1
|
||||
}
|
||||
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
||||
sql := fmt.Sprintf(
|
||||
`SELECT id, tenant_id, model_name, task_id, input_ref, request_payload
|
||||
FROM %s
|
||||
WHERE deleted_at IS NULL AND state = 0
|
||||
ORDER BY created_at ASC
|
||||
LIMIT %d
|
||||
FOR UPDATE SKIP LOCKED`,
|
||||
public.TableNameTask,
|
||||
batchSize,
|
||||
)
|
||||
r, err := tx.GetAll(sql)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
tasks = nil
|
||||
return nil
|
||||
}
|
||||
if err := r.Structs(&tasks); err != nil {
|
||||
return err
|
||||
}
|
||||
// 更新为 running
|
||||
now := time.Now()
|
||||
for _, t := range tasks {
|
||||
// tx.Model 不走 gfdb Hook,这里手动更新必要字段
|
||||
_, err = tx.Exec(
|
||||
fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask),
|
||||
now, now, t.Id,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// ListExpiredSuccess 获取已成功且过期的任务
|
||||
func (d *taskDao) ListExpiredSuccess(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
|
||||
if limit <= 0 {
|
||||
limit = 100
|
||||
}
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||||
Where(entity.AsynchTaskCol.State, 2).
|
||||
Where(entity.AsynchTaskCol.ExpireAt+" IS NOT NULL").
|
||||
Where(entity.AsynchTaskCol.ExpireAt+" < ?", gtime.Now()).
|
||||
Limit(limit).
|
||||
All()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
}
|
||||
|
||||
// ListTimeoutTasks 获取超时的排队/执行中任务
|
||||
func (d *taskDao) ListTimeoutTasks(ctx context.Context, timeout time.Duration, limit int) (list []*entity.AsynchTask, err error) {
|
||||
if limit <= 0 {
|
||||
limit = 100
|
||||
}
|
||||
deadline := gtime.New(time.Now().Add(-timeout))
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||||
WhereIn(entity.AsynchTaskCol.State, []int{0, 1}).
|
||||
Where(entity.AsynchTaskCol.UpdatedAt+" < ?", deadline).
|
||||
Limit(limit).
|
||||
All()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
}
|
||||
|
||||
// DebugPing 用于启动时检测数据库连通性(可选)
|
||||
func (d *taskDao) DebugPing(ctx context.Context) error {
|
||||
_, err := gfdb.DB(ctx).GetAll(ctx, "SELECT 1")
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -150,14 +150,6 @@ func (d *taskDao) UpdateTmpAfterModelGlobal(ctx context.Context, id int64, tmpFi
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *taskDao) SoftDeleteByTaskIDGlobal(ctx context.Context, taskID string) error {
|
||||
_, err := gfdb.DB(ctx).Exec(ctx,
|
||||
fmt.Sprintf(`UPDATE %s SET deleted_at=NOW(), updated_at=NOW() WHERE task_id=? AND deleted_at IS NULL`, public.TableNameTask),
|
||||
taskID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *taskDao) RollbackToPendingGlobal(ctx context.Context, id int64) error {
|
||||
_, err := gfdb.DB(ctx).Exec(ctx,
|
||||
fmt.Sprintf(`UPDATE %s SET state=0, enqueue_at=NOW(), updated_at=NOW() WHERE id=? AND state=1`, public.TableNameTask),
|
||||
|
||||
8
main.go
8
main.go
@@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"model-gateway/model/dto"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
@@ -61,7 +62,10 @@ func startAutoRunner(ctx context.Context) {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if _, err := service.AsyncWorker.RunOnce(ctx, batchSize, goroutines); err != nil {
|
||||
if _, err := service.AsyncWorker.RunOnce(ctx, &dto.RunWorkReq{
|
||||
BatchSize: batchSize,
|
||||
Goroutines: goroutines,
|
||||
}); err != nil {
|
||||
g.Log().Warningf(ctx, "[auto-worker] run once failed: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -83,7 +87,7 @@ func startAutoRunner(ctx context.Context) {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
service.Cleaner.RunOnce(ctx)
|
||||
_, _ = service.Cleaner.RunOnce(ctx)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -17,12 +17,14 @@ type CreateModelReq struct {
|
||||
Enabled *int `p:"enabled" json:"enabled" v:"in:0,1#启用参数只能为0或1" dc:"是否启用:0-禁用,1-启用(默认1)"`
|
||||
IsChatModel *int `p:"isChatModel" json:"isChatModel" v:"in:0,1#对话模型参数只能为0或1" dc:"是否为对话模型:0-否,1-是(默认0)"`
|
||||
IsOwner *int `p:"isOwner" json:"isOwner" v:"in:0,1#是否为所有者参数只能为0或1" dc:"是否为所有者:0-否,1-是(默认0)"`
|
||||
OperatorName string `p:"operatorName" json:"operatorName" v:"required#operatorName不能为空" dc:"运营商名称"`
|
||||
TokenConfig any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"`
|
||||
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥,用于模型认证"`
|
||||
Form any `p:"form" json:"form" dc:"动态表单配置(JSON),用于前端渲染配置项"`
|
||||
RequestMapping any `p:"requestMapping" json:"requestMapping" dc:"请求映射"`
|
||||
ResponseMapping any `p:"responseMapping" json:"responseMapping" dc:"返回映射"`
|
||||
ResponseBody any `p:"responseBody" json:"responseBody" dc:"返回主体"`
|
||||
TokenMapping string `p:"tokenMapping" json:"tokenMapping" dc:"token映射"`
|
||||
ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"`
|
||||
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数(默认10)"`
|
||||
QueueLimit int `p:"queueLimit" json:"queueLimit" dc:"排队队列上限(默认1000)"`
|
||||
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒,默认600)"`
|
||||
@@ -50,11 +52,13 @@ type UpdateModelReq struct {
|
||||
RequestMapping any `p:"requestMapping" json:"requestMapping" dc:"请求参数映射(可选更新)"`
|
||||
ResponseMapping any `p:"responseMapping" json:"responseMapping" dc:"返回参数映射(可选更新)"`
|
||||
ResponseBody any `p:"responseBody" json:"responseBody" dc:"返回主体(可选更新)"`
|
||||
TokenMapping string `p:"tokenMapping" json:"tokenMapping" dc:"token映射(可选更新)"`
|
||||
ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"`
|
||||
Enabled *int `p:"enabled" json:"enabled" dc:"是否启用:0-禁用,1-启用(可选更新)"`
|
||||
IsPrivate *int `p:"isPrivate" json:"isPrivate" v:"in:0,1#私有化参数只能为0或1" dc:"是否私有化:0-私有(默认) 1-公共"`
|
||||
IsChatModel *int `p:"isChatModel" json:"isChatModel" v:"in:0,1#对话模型参数只能为0或1" dc:"是否为对话模型:0-否,1-是(默认0)"`
|
||||
IsOwner *int `p:"isOwner" json:"isOwner" v:"in:0,1#是否为所有者参数只能为0或1" dc:"是否为所有者:0-否,1-是(默认0)"`
|
||||
OperatorName string `p:"operatorName" json:"operatorName" v:"required#operatorName不能为空" dc:"运营商名称"`
|
||||
TokenConfig any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"`
|
||||
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数(可选更新)"`
|
||||
QueueLimit int `p:"queueLimit" json:"queueLimit" dc:"排队队列上限(可选更新)"`
|
||||
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒)(可选更新)"`
|
||||
@@ -65,10 +69,18 @@ type UpdateModelReq struct {
|
||||
Remark string `p:"remark" json:"remark" dc:"备注说明(可选更新)"`
|
||||
}
|
||||
|
||||
type UpdateModelRes struct {
|
||||
ID int64 `json:"id,string" dc:"配置ID"`
|
||||
}
|
||||
|
||||
// DeleteModelReq 删除模型配置
|
||||
type DeleteModelReq struct {
|
||||
g.Meta `path:"/deleteModel" method:"delete" tags:"模型管理" summary:"删除模型配置" dc:"删除指定ID的模型配置"`
|
||||
ID string `p:"id" json:"id,string" v:"required#id不能为空" dc:"配置ID"`
|
||||
ID int64 `p:"id" json:"id,string" v:"required#id不能为空" dc:"配置ID"`
|
||||
}
|
||||
|
||||
type DeleteModelRes struct {
|
||||
ID int64 `json:"id,string" dc:"配置ID"`
|
||||
}
|
||||
|
||||
// GetModelReq 获取模型配置详情
|
||||
@@ -128,7 +140,14 @@ type UpdateChatModelReq struct {
|
||||
g.Meta `path:"/updateChatModel" method:"post" tags:"模型管理" summary:"更新聊天模型" dc:"更新指定模型的聊天模型"`
|
||||
Id int64 `p:"id" json:"id" v:"required#model不能为空" dc:"模型id"`
|
||||
}
|
||||
type UpdateChatModelRes struct {
|
||||
ID int64 `json:"id,string" dc:"模型ID"`
|
||||
}
|
||||
|
||||
type GetIsChatModelReq struct {
|
||||
g.Meta `path:"/getIsChatModel" method:"get" tags:"模型管理" summary:"获取模型是否为聊天模型" dc:"根据模型ID获取是否为聊天模型"`
|
||||
}
|
||||
|
||||
type GetIsChatModelRes struct {
|
||||
Model any `json:"model" dc:"模型详情"`
|
||||
}
|
||||
|
||||
@@ -4,58 +4,62 @@ import "gitea.com/red-future/common/beans"
|
||||
|
||||
type asynchModelCol struct {
|
||||
beans.SQLBaseCol
|
||||
ModelName string
|
||||
ModelType string
|
||||
BaseURL string
|
||||
HttpMethod string
|
||||
HeadMsg string
|
||||
FormJSON string
|
||||
RequestMapping string
|
||||
ResponseMapping string
|
||||
ResponseBody string
|
||||
TokenMapping string
|
||||
Prompt string
|
||||
IsPrivate string
|
||||
IsChatModel string
|
||||
ApiKey string
|
||||
Enabled string
|
||||
MaxConcurrency string
|
||||
QueueLimit string
|
||||
TimeoutSeconds string
|
||||
ExpectedSeconds string
|
||||
RetryTimes string
|
||||
RetryQueueMaxSecs string
|
||||
AutoCleanSeconds string
|
||||
Remark string
|
||||
IsOwner string
|
||||
ModelName string
|
||||
ModelType string
|
||||
BaseURL string
|
||||
HttpMethod string
|
||||
HeadMsg string
|
||||
FormJSON string
|
||||
RequestMapping string
|
||||
ResponseMapping string
|
||||
ResponseBody string
|
||||
ResponseTokenField string
|
||||
Prompt string
|
||||
IsPrivate string
|
||||
IsChatModel string
|
||||
ApiKey string
|
||||
Enabled string
|
||||
MaxConcurrency string
|
||||
QueueLimit string
|
||||
TimeoutSeconds string
|
||||
ExpectedSeconds string
|
||||
RetryTimes string
|
||||
RetryQueueMaxSecs string
|
||||
AutoCleanSeconds string
|
||||
Remark string
|
||||
IsOwner string
|
||||
OperatorName string
|
||||
TokenConfig string
|
||||
}
|
||||
|
||||
var AsynchModelCol = asynchModelCol{
|
||||
SQLBaseCol: beans.DefSQLBaseCol,
|
||||
ModelName: "model_name",
|
||||
ModelType: "model_type",
|
||||
BaseURL: "base_url",
|
||||
HttpMethod: "http_method",
|
||||
HeadMsg: "head_msg",
|
||||
FormJSON: "form_json",
|
||||
RequestMapping: "request_mapping",
|
||||
ResponseMapping: "response_mapping",
|
||||
ResponseBody: "response_body",
|
||||
TokenMapping: "token_mapping",
|
||||
Prompt: "prompt",
|
||||
IsPrivate: "is_private",
|
||||
IsChatModel: "is_chat_model",
|
||||
ApiKey: "api_key",
|
||||
Enabled: "enabled",
|
||||
MaxConcurrency: "max_concurrency",
|
||||
QueueLimit: "queue_limit",
|
||||
TimeoutSeconds: "timeout_seconds",
|
||||
ExpectedSeconds: "expected_seconds",
|
||||
RetryTimes: "retry_times",
|
||||
RetryQueueMaxSecs: "retry_queue_max_seconds",
|
||||
AutoCleanSeconds: "auto_clean_seconds",
|
||||
Remark: "remark",
|
||||
IsOwner: "is_owner",
|
||||
SQLBaseCol: beans.DefSQLBaseCol,
|
||||
ModelName: "model_name",
|
||||
ModelType: "model_type",
|
||||
BaseURL: "base_url",
|
||||
HttpMethod: "http_method",
|
||||
HeadMsg: "head_msg",
|
||||
FormJSON: "form_json",
|
||||
RequestMapping: "request_mapping",
|
||||
ResponseMapping: "response_mapping",
|
||||
ResponseBody: "response_body",
|
||||
ResponseTokenField: "response_token_field",
|
||||
Prompt: "prompt",
|
||||
IsPrivate: "is_private",
|
||||
IsChatModel: "is_chat_model",
|
||||
ApiKey: "api_key",
|
||||
Enabled: "enabled",
|
||||
MaxConcurrency: "max_concurrency",
|
||||
QueueLimit: "queue_limit",
|
||||
TimeoutSeconds: "timeout_seconds",
|
||||
ExpectedSeconds: "expected_seconds",
|
||||
RetryTimes: "retry_times",
|
||||
RetryQueueMaxSecs: "retry_queue_max_seconds",
|
||||
AutoCleanSeconds: "auto_clean_seconds",
|
||||
Remark: "remark",
|
||||
IsOwner: "is_owner",
|
||||
OperatorName: "operator_name",
|
||||
TokenConfig: "token_config",
|
||||
}
|
||||
|
||||
// AsynchModel 异步模型配置
|
||||
@@ -70,7 +74,7 @@ type AsynchModel struct {
|
||||
RequestMapping any `orm:"request_mapping" json:"requestMapping"`
|
||||
ResponseMapping any `orm:"response_mapping" json:"responseMapping"`
|
||||
ResponseBody any `orm:"response_body" json:"responseBody"`
|
||||
TokenMapping string `orm:"token_mapping" json:"tokenMapping"`
|
||||
ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"`
|
||||
Prompt string `orm:"prompt" json:"prompt"`
|
||||
IsPrivate *int `orm:"is_private" json:"isPrivate"`
|
||||
IsChatModel *int `orm:"is_chat_model" json:"isChatModel"`
|
||||
@@ -84,5 +88,7 @@ type AsynchModel struct {
|
||||
RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"retryQueueMaxSeconds"`
|
||||
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
|
||||
Remark string `orm:"remark" json:"remark"`
|
||||
IsOwner *int `json:"isOwner" orm:"is_owner"` // 1=当前用户创建的,0=超级管理员的
|
||||
IsOwner *int `json:"isOwner" orm:"is_owner"`
|
||||
OperatorName string `orm:"operator_name" json:"operatorName"`
|
||||
TokenConfig any `orm:"token_config" json:"tokenConfig"`
|
||||
}
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
package entity
|
||||
|
||||
import "gitea.com/red-future/common/beans"
|
||||
|
||||
type asynchModelTypeCol struct {
|
||||
beans.SQLBaseCol
|
||||
TypeID string
|
||||
TypeName string
|
||||
Remark string
|
||||
}
|
||||
|
||||
var AsynchModelTypeCol = asynchModelTypeCol{
|
||||
SQLBaseCol: beans.DefSQLBaseCol,
|
||||
TypeID: "type_id",
|
||||
TypeName: "type_name",
|
||||
Remark: "remark",
|
||||
}
|
||||
|
||||
// AsynchModelType 模型类型(图片/音频/视频等)
|
||||
type AsynchModelType struct {
|
||||
beans.SQLBaseDO `orm:",inline"`
|
||||
TypeID int `orm:"type_id" json:"typeId"`
|
||||
TypeName string `orm:"type_name" json:"type"`
|
||||
Remark string `orm:"remark" json:"remark"`
|
||||
}
|
||||
|
||||
@@ -2,8 +2,10 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"model-gateway/model/dto"
|
||||
|
||||
"model-gateway/consts/public"
|
||||
"model-gateway/model/entity"
|
||||
@@ -34,9 +36,12 @@ type AutoTuneResult struct {
|
||||
// - 基于吞吐与 P90 执行耗时估算 max_concurrency 的运行时值(不超过 cap)
|
||||
// - queue_limit 与 expected_seconds 绑定(允许排队时间 = expected_seconds * 2),生成运行时值(不超过 cap)
|
||||
// - 单次调整幅度限制 ±50%,写入 Redis(带 TTL)
|
||||
func AutoTune(ctx context.Context, windowSeconds int) ([]AutoTuneResult, error) {
|
||||
if windowSeconds <= 0 {
|
||||
windowSeconds = 3600
|
||||
func AutoTune(ctx context.Context, req *dto.AutoTuneReq) (res *dto.AutoTuneRes, err error) {
|
||||
if req == nil {
|
||||
return nil, errors.New("request cannot be nil")
|
||||
}
|
||||
if req.WindowSeconds <= 0 {
|
||||
req.WindowSeconds = 3600 // 默认1小时
|
||||
}
|
||||
// 1) 读取模型配置(cap),按 model_name 聚合去重(如果表里有多租户重复数据,取较大上限)
|
||||
var modelRows []*entity.AsynchModel
|
||||
@@ -68,7 +73,7 @@ func AutoTune(ctx context.Context, windowSeconds int) ([]AutoTuneResult, error)
|
||||
}
|
||||
}
|
||||
if len(modelMap) == 0 {
|
||||
return []AutoTuneResult{}, nil
|
||||
return nil, errors.New("no models found")
|
||||
}
|
||||
|
||||
// 2) 统计指定窗口:按 model_name 计算 cnt 和 P90 执行耗时
|
||||
@@ -89,7 +94,7 @@ SELECT model_name,
|
||||
AND finished_at IS NOT NULL
|
||||
AND finished_at >= (NOW() - (? || ' seconds')::interval)
|
||||
GROUP BY model_name`, public.TableNameTask)
|
||||
r, err := gfdb.DB(ctx).GetAll(ctx, sql, windowSeconds)
|
||||
r, err := gfdb.DB(ctx).GetAll(ctx, sql, req.WindowSeconds)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -189,6 +194,8 @@ SELECT model_name,
|
||||
})
|
||||
}
|
||||
|
||||
g.Log().Infof(ctx, "[auto_tune] done models=%d windowSeconds=%d", len(out), windowSeconds)
|
||||
return out, nil
|
||||
g.Log().Infof(ctx, "[auto_tune] done models=%d windowSeconds=%d", len(out), req.WindowSeconds)
|
||||
return &dto.AutoTuneRes{
|
||||
List: out,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1,67 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"model-gateway/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/http"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
// triggerCallback 任务成功后的回调:
|
||||
// - JSON body 参数:task_id/state/oss_file/file_type/text(可选)
|
||||
func triggerCallback(ctx context.Context, t *entity.AsynchTask) {
|
||||
callbackURL := t.BizName + t.CallbackURL
|
||||
headers := forwardHeaders(ctx)
|
||||
var req struct{}
|
||||
payload := map[string]interface{}{
|
||||
"task_id": t.TaskID,
|
||||
"state": t.State,
|
||||
"oss_file": t.OssFile,
|
||||
"file_type": t.FileType,
|
||||
"text": t.TextResult,
|
||||
"error_msg": t.ErrorMsg,
|
||||
}
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[回调] JSON序列化失败 taskId=%s 错误=%v", t.TaskID, err)
|
||||
return
|
||||
}
|
||||
g.Log().Infof(ctx, "[回调] 开始发送 taskId=%s 回调地址=%s 请求头数量=%d 消息体大小=%d字节",
|
||||
t.TaskID, callbackURL, len(headers), len(jsonData))
|
||||
|
||||
err = http.Post(ctx, callbackURL, headers, &req, jsonData)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[回调] 发送失败 taskId=%s 回调地址=%s 错误=%v", t.TaskID, callbackURL, err)
|
||||
return
|
||||
}
|
||||
g.Log().Infof(ctx, "[回调] 发送成功 taskId=%s 回调地址=%s 消息体大小=%d字节", t.TaskID, callbackURL, len(jsonData))
|
||||
}
|
||||
|
||||
// triggerPromptsCallback 任务成功后的提示词回调
|
||||
// - JSON body 参数:epicycleId(轮次id)/textResult(模型回答消息)
|
||||
func triggerPromptsCallback(ctx context.Context, t *entity.AsynchTask, epicycleId int64) {
|
||||
callbackURL := "prompts-core/session/sessionCallback"
|
||||
headers := forwardHeaders(ctx)
|
||||
var req struct{}
|
||||
payload := map[string]interface{}{
|
||||
"epicycleId": epicycleId,
|
||||
"text": t.TextResult,
|
||||
}
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[提示词回调] JSON序列化失败 epicycleId=%d 错误=%v", epicycleId, err)
|
||||
return
|
||||
}
|
||||
g.Log().Infof(ctx, "[提示词回调] 开始发送 epicycleId=%d 回调地址=%s 请求头数量=%d 消息体大小=%d字节",
|
||||
t.EpicycleId, callbackURL, len(headers), len(jsonData))
|
||||
|
||||
err = http.Post(ctx, callbackURL, headers, &req, jsonData)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[提示词回调] 发送失败 epicycleId=%d 回调地址=%s 错误=%v", t.EpicycleId, callbackURL, err)
|
||||
return
|
||||
}
|
||||
g.Log().Infof(ctx, "[提示词回调] 发送成功 epicycleId=%d 回调地址=%s 消息体大小=%d字节", t.EpicycleId, callbackURL, len(jsonData))
|
||||
}
|
||||
@@ -2,6 +2,8 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"model-gateway/model/dto"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"model-gateway/dao"
|
||||
@@ -14,14 +16,14 @@ var Cleaner = &cleaner{}
|
||||
type cleaner struct{}
|
||||
|
||||
// RunOnce 由上层定时任务触发:执行一次清理/重试
|
||||
func (c *cleaner) RunOnce(ctx context.Context) {
|
||||
func (c *cleaner) RunOnce(ctx context.Context) (res *dto.CleanWorkRes, err error) {
|
||||
// 1) 清理已下载(state=4)且过期的任务(硬删除 + OSS)
|
||||
expired, err := dao.Task.ListExpiredDownloadedGlobal(ctx, 200)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[cleaner] list expired(downloaded) error: %v", err)
|
||||
} else {
|
||||
for _, t := range expired {
|
||||
deleteTmpResult(t.TmpFile)
|
||||
_ = os.Remove(t.TmpFile)
|
||||
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
|
||||
}
|
||||
g.Log().Infof(ctx, "[cleaner] expired(downloaded) cleaned, count=%d", len(expired))
|
||||
@@ -82,11 +84,14 @@ func (c *cleaner) RunOnce(ctx context.Context) {
|
||||
g.Log().Errorf(ctx, "[cleaner] list failed exhausted error: %v", err)
|
||||
} else {
|
||||
for _, t := range exhausted {
|
||||
deleteTmpResult(t.TmpFile)
|
||||
_ = os.Remove(t.TmpFile)
|
||||
// 重试耗尽硬删除:释放闸门占位(兜底,若此前已释放则幂等)
|
||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
|
||||
}
|
||||
g.Log().Infof(ctx, "[cleaner] failed exhausted cleaned, count=%d", len(exhausted))
|
||||
}
|
||||
return &dto.CleanWorkRes{
|
||||
Ok: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1,47 +1 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// DetectFileType 根据返回的二进制内容推断 contentType + 扩展名(尽量稳定)
|
||||
func DetectFileType(data []byte) (contentType string, ext string) {
|
||||
if len(data) == 0 {
|
||||
return "application/octet-stream", ""
|
||||
}
|
||||
ct := http.DetectContentType(data)
|
||||
// gateway.DetectContentType 可能带 charset 等参数:text/plain; charset=utf-8
|
||||
if idx := strings.Index(ct, ";"); idx > 0 {
|
||||
ct = strings.TrimSpace(ct[:idx])
|
||||
}
|
||||
switch ct {
|
||||
case "audio/mpeg":
|
||||
return ct, ".mp3"
|
||||
case "audio/wave", "audio/wav", "audio/x-wav":
|
||||
return ct, ".wav"
|
||||
case "video/mp4":
|
||||
return ct, ".mp4"
|
||||
case "image/png":
|
||||
return ct, ".png"
|
||||
case "image/jpeg":
|
||||
return ct, ".jpg"
|
||||
case "application/pdf":
|
||||
return ct, ".pdf"
|
||||
case "text/plain":
|
||||
return ct, ".txt"
|
||||
case "application/json":
|
||||
return ct, ".json"
|
||||
default:
|
||||
// 兜底:尝试从 ct 截取 subtype 作为后缀(例如 application/json)
|
||||
if parts := strings.Split(ct, "/"); len(parts) == 2 {
|
||||
sub := parts[1]
|
||||
// 避免出现 "plain; charset=utf-8" 之类的后缀
|
||||
if idx := strings.Index(sub, ";"); idx > 0 {
|
||||
sub = strings.TrimSpace(sub[:idx])
|
||||
}
|
||||
return ct, "." + sub
|
||||
}
|
||||
return ct, ""
|
||||
}
|
||||
}
|
||||
|
||||
171
service/gateway/gateway_http_service.go
Normal file
171
service/gateway/gateway_http_service.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"mime/multipart"
|
||||
"model-gateway/common/util"
|
||||
"model-gateway/model/entity"
|
||||
"time"
|
||||
|
||||
commonHttp "gitea.com/red-future/common/http"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/guid"
|
||||
)
|
||||
|
||||
type uploadFileResponse struct {
|
||||
FileURL string `json:"fileURL"` // 文件 URL
|
||||
FileSize int `json:"fileSize"` // 文件大小(字节)
|
||||
FileName string `json:"fileName"` // 文件名
|
||||
FileFormat string `json:"fileFormat"` // 文件格式
|
||||
FileAddressPrefix string `json:"fileAddressPrefix"` // 文件地址前缀
|
||||
}
|
||||
|
||||
func UploadByTask(ctx context.Context, _ *entity.AsynchTask, data []byte, fileExt string, _ string) (ossURL string, err error) {
|
||||
// multipart
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
|
||||
ext := fileExt
|
||||
if ext == "" {
|
||||
ext = ".bin"
|
||||
}
|
||||
if ext[0] != '.' {
|
||||
ext = "." + ext
|
||||
}
|
||||
|
||||
filename := fmt.Sprintf("asynch_%d_%s%s", time.Now().Unix(), guid.S(), ext)
|
||||
part, err := writer.CreateFormFile("file", filename)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if _, err = part.Write(data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
headers := util.ForwardHeaders(ctx)
|
||||
fullURL := "oss/file/uploadFile"
|
||||
g.Log().Infof(ctx, "[OSS] upload start url=%s filename=%s size=%d", fullURL, filename, len(data))
|
||||
|
||||
var resp uploadFileResponse
|
||||
if err = commonHttp.Post(ctx, fullURL, headers, &resp, body.Bytes()); err != nil {
|
||||
return "", err
|
||||
}
|
||||
g.Log().Infof(ctx, "[OSS] upload success url=%s size=%d format=%s", resp.FileURL, resp.FileSize, resp.FileFormat)
|
||||
return resp.FileURL, nil
|
||||
}
|
||||
|
||||
// TriggerCallback 任务成功后的回调:
|
||||
// - JSON body 参数:task_id/state/oss_file/file_type/text(可选)
|
||||
func TriggerCallback(ctx context.Context, t *entity.AsynchTask) {
|
||||
headers := util.ForwardHeaders(ctx)
|
||||
var req struct{}
|
||||
payload := map[string]interface{}{
|
||||
"task_id": t.TaskID,
|
||||
"state": t.State,
|
||||
"oss_file": t.OssFile,
|
||||
"file_type": t.FileType,
|
||||
"text": t.TextResult,
|
||||
"error_msg": t.ErrorMsg,
|
||||
}
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[回调] JSON序列化失败 taskId=%s 错误=%v", t.TaskID, err)
|
||||
return
|
||||
}
|
||||
g.Log().Infof(ctx, "[回调] 开始发送 taskId=%s 回调地址=%s 请求头数量=%d 消息体大小=%d字节",
|
||||
t.TaskID, t.CallbackURL, len(headers), len(jsonData))
|
||||
|
||||
err = commonHttp.Post(ctx, t.CallbackURL, headers, &req, jsonData)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[回调] 发送失败 taskId=%s 回调地址=%s 错误=%v", t.TaskID, t.CallbackURL, err)
|
||||
return
|
||||
}
|
||||
g.Log().Infof(ctx, "[回调] 发送成功 taskId=%s 回调地址=%s 消息体大小=%d字节", t.TaskID, t.CallbackURL, len(jsonData))
|
||||
}
|
||||
|
||||
// TriggerPromptsCallback 任务成功后的提示词回调
|
||||
// - JSON body 参数:epicycleId(轮次id)/textResult(模型回答消息)
|
||||
func TriggerPromptsCallback(ctx context.Context, t *entity.AsynchTask, epicycleId int64) {
|
||||
callbackURL := "prompts-core/session/sessionCallback"
|
||||
headers := util.ForwardHeaders(ctx)
|
||||
var req struct{}
|
||||
payload := map[string]interface{}{
|
||||
"epicycleId": epicycleId,
|
||||
"text": t.TextResult,
|
||||
}
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[提示词回调] JSON序列化失败 epicycleId=%d 错误=%v", epicycleId, err)
|
||||
return
|
||||
}
|
||||
g.Log().Infof(ctx, "[提示词回调] 开始发送 epicycleId=%d 回调地址=%s 请求头数量=%d 消息体大小=%d字节",
|
||||
t.EpicycleId, callbackURL, len(headers), len(jsonData))
|
||||
|
||||
err = commonHttp.Post(ctx, callbackURL, headers, &req, jsonData)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[提示词回调] 发送失败 epicycleId=%d 回调地址=%s 错误=%v", t.EpicycleId, callbackURL, err)
|
||||
return
|
||||
}
|
||||
g.Log().Infof(ctx, "[提示词回调] 发送成功 epicycleId=%d 回调地址=%s 消息体大小=%d字节", t.EpicycleId, callbackURL, len(jsonData))
|
||||
}
|
||||
|
||||
// IsSuperAdmin 调用admin-go服务检查是否是超级管理员
|
||||
func IsSuperAdmin(ctx context.Context) (res bool, err error) {
|
||||
headers := util.ForwardHeaders(ctx)
|
||||
var r = make(map[string]bool)
|
||||
if err = commonHttp.Get(ctx, "admin-go/api/v1/system/user/checkIsSuperAdmin", headers, &r); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return r["isSuperAdmin"], err
|
||||
}
|
||||
|
||||
//// callback 向回调地址 POST 任务结果(与查询接口 GetTaskRes 出参一致)
|
||||
//func (s *audioTaskService) callback(ctx context.Context, taskID, status, errMsg, callbackURL string) {
|
||||
// if callbackURL == "" {
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// task, _ := dao.TranscribeTask.GetByTaskID(ctx, taskID)
|
||||
// if task == nil {
|
||||
// g.Log().Errorf(ctx, "[回调 %s] 任务不存在", taskID)
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// detailList, _ := dao.TranscribeTaskDetail.ListByTaskID(ctx, taskID)
|
||||
// detailItems := make([]dto.TranscribeTaskDetailItem, 0, len(detailList))
|
||||
// for i := range detailList {
|
||||
// detailItems = append(detailItems, dao.DetailEntityToItem(&detailList[i]))
|
||||
// }
|
||||
//
|
||||
// // 构建与查询接口一致的 taskInfo
|
||||
// taskInfo := dao.EntityToItem(task)
|
||||
//
|
||||
// // 兼容历史数据: 从 result 中补全 scenes 等字段
|
||||
// detailItems = enrichDetailsFromResult(task.Result, detailItems)
|
||||
//
|
||||
// payload := dto.CallbackPayload{
|
||||
// TaskInfo: taskInfo,
|
||||
// DetailList: detailItems,
|
||||
// }
|
||||
//
|
||||
// body, _ := json.Marshal(payload)
|
||||
//
|
||||
// // 透传调用方的用户信息
|
||||
// userJSON, _ := json.Marshal(beans.User{UserName: "admin", TenantId: 1})
|
||||
//
|
||||
// req, _ := http.NewRequest("POST", callbackURL, bytes.NewReader(body))
|
||||
// req.Header.Set("Content-Type", "application/json")
|
||||
// req.Header.Set("X-User-Info", string(userJSON))
|
||||
//
|
||||
// resp, reqErr := http.DefaultClient.Do(req)
|
||||
// if reqErr != nil {
|
||||
// g.Log().Errorf(ctx, "[回调 %s] 请求失败: %v", taskID, reqErr)
|
||||
// return
|
||||
// }
|
||||
// defer resp.Body.Close()
|
||||
//
|
||||
// respBody, _ := io.ReadAll(resp.Body)
|
||||
// g.Log().Infof(ctx, "[回调 %s] 响应 status=%d, body=%s", taskID, resp.StatusCode, string(respBody))
|
||||
//}
|
||||
@@ -1,53 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
// asyncCtx 固化异步执行所需的 token/user,避免请求结束后丢失(仅在“同请求内起 goroutine”有用)。
|
||||
// 本项目当前是“落库 + 后台 worker”模式,因此还会把必要信息持久化到任务表的 request_payload 中。
|
||||
func asyncCtx(ctx context.Context) context.Context {
|
||||
asyncCtx := context.WithoutCancel(ctx)
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
if token := r.Header.Get("Authorization"); token != "" {
|
||||
asyncCtx = context.WithValue(asyncCtx, "token", token)
|
||||
}
|
||||
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
|
||||
asyncCtx = context.WithValue(asyncCtx, "xUserInfo", userInfo)
|
||||
}
|
||||
}
|
||||
if user, err := utils.GetUserInfo(ctx); err == nil && user != nil {
|
||||
asyncCtx = context.WithValue(asyncCtx, "user", user)
|
||||
}
|
||||
return asyncCtx
|
||||
}
|
||||
|
||||
// forwardHeaders 透传调用链路中必须的头信息(优先使用 ctx 里固化的 token / xUserInfo)。
|
||||
func forwardHeaders(ctx context.Context) map[string]string {
|
||||
headers := make(map[string]string)
|
||||
|
||||
if token, ok := ctx.Value("token").(string); ok && token != "" {
|
||||
headers["Authorization"] = token
|
||||
}
|
||||
if x, ok := ctx.Value("xUserInfo").(string); ok && x != "" {
|
||||
headers["X-User-Info"] = x
|
||||
}
|
||||
|
||||
// 兜底:从请求头拿
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
if headers["Authorization"] == "" {
|
||||
if token := r.Header.Get("Authorization"); token != "" {
|
||||
headers["Authorization"] = token
|
||||
}
|
||||
}
|
||||
if headers["X-User-Info"] == "" {
|
||||
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
|
||||
headers["X-User-Info"] = userInfo
|
||||
}
|
||||
}
|
||||
}
|
||||
return headers
|
||||
}
|
||||
@@ -3,13 +3,15 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"model-gateway/common/util"
|
||||
"model-gateway/consts/public"
|
||||
"model-gateway/dao"
|
||||
"model-gateway/model/dto"
|
||||
"model-gateway/model/entity"
|
||||
"model-gateway/service/gateway"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"gitea.com/red-future/common/http"
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
@@ -20,28 +22,20 @@ var Model = &modelService{}
|
||||
|
||||
type modelService struct{}
|
||||
|
||||
// IsSuperAdmin 调用admin-go服务检查是否是超级管理员
|
||||
func (s *modelService) IsSuperAdmin(ctx context.Context) (res bool, err error) {
|
||||
headers := forwardHeaders(ctx)
|
||||
var r = make(map[string]bool)
|
||||
if err = http.Get(ctx, "admin-go/api/v1/system/user/checkIsSuperAdmin", headers, &r); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return r["isSuperAdmin"], err
|
||||
}
|
||||
|
||||
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
|
||||
// 获取当前会话模型
|
||||
if !g.IsEmpty(req.IsChatModel) && *req.IsChatModel == 1 {
|
||||
var model *entity.AsynchModel
|
||||
model, err = dao.Model.GetByIsChatModel(ctx)
|
||||
model, err = dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
IsChatModel: new(1),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 如果有会话模型,那就改变为 0
|
||||
if model != nil {
|
||||
_, err = dao.Model.Update(ctx, &dto.UpdateModelReq{
|
||||
ID: model.Id,
|
||||
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: model.Id},
|
||||
IsChatModel: gconv.PtrInt(0),
|
||||
})
|
||||
if err != nil {
|
||||
@@ -51,14 +45,40 @@ func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res
|
||||
}
|
||||
|
||||
req.IsOwner = gconv.PtrInt(1)
|
||||
admin, err := s.IsSuperAdmin(ctx)
|
||||
admin, err := gateway.IsSuperAdmin(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if admin {
|
||||
req.IsOwner = gconv.PtrInt(0)
|
||||
}
|
||||
id, err := dao.Model.Insert(ctx, req)
|
||||
id, err := dao.Model.Insert(ctx, &entity.AsynchModel{
|
||||
ModelName: req.ModelName,
|
||||
ModelType: req.ModelType,
|
||||
BaseURL: req.BaseURL,
|
||||
HttpMethod: req.HttpMethod,
|
||||
HeadMsg: req.HeadMsg,
|
||||
Form: req.Form,
|
||||
RequestMapping: req.RequestMapping,
|
||||
ResponseMapping: req.ResponseMapping,
|
||||
ResponseBody: req.ResponseBody,
|
||||
ResponseTokenField: req.ResponseTokenField,
|
||||
IsPrivate: req.IsPrivate,
|
||||
IsChatModel: req.IsChatModel,
|
||||
ApiKey: req.ApiKey,
|
||||
Enabled: req.Enabled,
|
||||
MaxConcurrency: req.MaxConcurrency,
|
||||
QueueLimit: req.QueueLimit,
|
||||
TimeoutSeconds: req.TimeoutSeconds,
|
||||
ExpectedSeconds: req.ExpectedSeconds,
|
||||
RetryTimes: req.RetryTimes,
|
||||
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
|
||||
AutoCleanSeconds: req.AutoCleanSeconds,
|
||||
Remark: req.Remark,
|
||||
IsOwner: req.IsOwner,
|
||||
OperatorName: req.OperatorName,
|
||||
TokenConfig: req.TokenConfig,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -69,7 +89,9 @@ func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) erro
|
||||
//根据当前 isChatModel 来判断是否更新模型
|
||||
if req.IsChatModel == gconv.PtrInt(1) {
|
||||
//判断当前用户是否有会话模型
|
||||
model, err := dao.Model.GetByIsChatModel(ctx)
|
||||
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
IsChatModel: new(1),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -79,68 +101,146 @@ func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) erro
|
||||
}
|
||||
|
||||
req.IsOwner = gconv.PtrInt(1)
|
||||
admin, err := s.IsSuperAdmin(ctx)
|
||||
admin, err := gateway.IsSuperAdmin(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if admin {
|
||||
req.IsOwner = gconv.PtrInt(0)
|
||||
_, err = dao.Model.Update(ctx, req)
|
||||
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
|
||||
ModelName: req.ModelName,
|
||||
ModelType: req.ModelType,
|
||||
BaseURL: req.BaseURL,
|
||||
HttpMethod: req.HttpMethod,
|
||||
HeadMsg: req.HeadMsg,
|
||||
Form: req.Form,
|
||||
RequestMapping: req.RequestMapping,
|
||||
ResponseMapping: req.ResponseMapping,
|
||||
ResponseBody: req.ResponseBody,
|
||||
ResponseTokenField: req.ResponseTokenField,
|
||||
IsPrivate: req.IsPrivate,
|
||||
IsChatModel: req.IsChatModel,
|
||||
ApiKey: req.ApiKey,
|
||||
Enabled: req.Enabled,
|
||||
MaxConcurrency: req.MaxConcurrency,
|
||||
QueueLimit: req.QueueLimit,
|
||||
TimeoutSeconds: req.TimeoutSeconds,
|
||||
ExpectedSeconds: req.ExpectedSeconds,
|
||||
RetryTimes: req.RetryTimes,
|
||||
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
|
||||
AutoCleanSeconds: req.AutoCleanSeconds,
|
||||
Remark: req.Remark,
|
||||
IsOwner: req.IsOwner,
|
||||
OperatorName: req.OperatorName,
|
||||
TokenConfig: req.TokenConfig,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var user *beans.User
|
||||
user, err = utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 判断当前传过来的模型id的模型是否是超级管理员的。如果是超管的进行创建,否则更新
|
||||
var count int
|
||||
count, err = dao.Model.Count(ctx, &dto.GetModelReq{
|
||||
ID: req.ID,
|
||||
Creator: user.UserName,
|
||||
model, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count == 0 {
|
||||
if model.TenantId == 1 {
|
||||
insertDto := new(dto.CreateModelReq)
|
||||
err = gconv.Struct(req, insertDto)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = dao.Model.Insert(ctx, insertDto)
|
||||
_, err = dao.Model.Insert(ctx, &entity.AsynchModel{
|
||||
ModelName: req.ModelName,
|
||||
ModelType: req.ModelType,
|
||||
BaseURL: req.BaseURL,
|
||||
HttpMethod: req.HttpMethod,
|
||||
HeadMsg: req.HeadMsg,
|
||||
Form: req.Form,
|
||||
RequestMapping: req.RequestMapping,
|
||||
ResponseMapping: req.ResponseMapping,
|
||||
ResponseBody: req.ResponseBody,
|
||||
ResponseTokenField: req.ResponseTokenField,
|
||||
IsPrivate: req.IsPrivate,
|
||||
IsChatModel: req.IsChatModel,
|
||||
ApiKey: req.ApiKey,
|
||||
Enabled: req.Enabled,
|
||||
MaxConcurrency: req.MaxConcurrency,
|
||||
QueueLimit: req.QueueLimit,
|
||||
TimeoutSeconds: req.TimeoutSeconds,
|
||||
ExpectedSeconds: req.ExpectedSeconds,
|
||||
RetryTimes: req.RetryTimes,
|
||||
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
|
||||
AutoCleanSeconds: req.AutoCleanSeconds,
|
||||
Remark: req.Remark,
|
||||
IsOwner: req.IsOwner,
|
||||
OperatorName: req.OperatorName,
|
||||
TokenConfig: req.TokenConfig,
|
||||
})
|
||||
return err
|
||||
}
|
||||
_, err = dao.Model.Update(ctx, req)
|
||||
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
|
||||
ModelName: req.ModelName,
|
||||
ModelType: req.ModelType,
|
||||
BaseURL: req.BaseURL,
|
||||
HttpMethod: req.HttpMethod,
|
||||
HeadMsg: req.HeadMsg,
|
||||
Form: req.Form,
|
||||
RequestMapping: req.RequestMapping,
|
||||
ResponseMapping: req.ResponseMapping,
|
||||
ResponseBody: req.ResponseBody,
|
||||
ResponseTokenField: req.ResponseTokenField,
|
||||
IsPrivate: req.IsPrivate,
|
||||
IsChatModel: req.IsChatModel,
|
||||
ApiKey: req.ApiKey,
|
||||
Enabled: req.Enabled,
|
||||
MaxConcurrency: req.MaxConcurrency,
|
||||
QueueLimit: req.QueueLimit,
|
||||
TimeoutSeconds: req.TimeoutSeconds,
|
||||
ExpectedSeconds: req.ExpectedSeconds,
|
||||
RetryTimes: req.RetryTimes,
|
||||
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
|
||||
AutoCleanSeconds: req.AutoCleanSeconds,
|
||||
Remark: req.Remark,
|
||||
IsOwner: req.IsOwner,
|
||||
OperatorName: req.OperatorName,
|
||||
TokenConfig: req.TokenConfig,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *modelService) Delete(ctx context.Context, id string) error {
|
||||
_, err := dao.Model.DeleteByID(ctx, id)
|
||||
func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) error {
|
||||
_, err := dao.Model.Delete(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *modelService) Get(ctx context.Context, id int64) (*entity.AsynchModel, error) {
|
||||
model, err := dao.Model.Get(ctx, id)
|
||||
func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetModelRes, error) {
|
||||
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
model.Form = ParseJSONField(model.Form)
|
||||
model.RequestMapping = ParseJSONField(model.RequestMapping)
|
||||
model.ResponseMapping = ParseJSONField(model.ResponseMapping)
|
||||
model.ResponseBody = ParseJSONField(model.ResponseBody)
|
||||
return model, nil
|
||||
model.Form = util.ParseJSONField(model.Form)
|
||||
model.RequestMapping = util.ParseJSONField(model.RequestMapping)
|
||||
model.ResponseMapping = util.ParseJSONField(model.ResponseMapping)
|
||||
model.ResponseBody = util.ParseJSONField(model.ResponseBody)
|
||||
return &dto.GetModelRes{
|
||||
Model: model,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (list []*entity.AsynchModel, total int, err error) {
|
||||
func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) {
|
||||
var models []*entity.AsynchModel
|
||||
|
||||
req.IsOwner = gconv.PtrInt(1)
|
||||
admin, err := s.IsSuperAdmin(ctx)
|
||||
admin, err := gateway.IsSuperAdmin(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -151,63 +251,55 @@ func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (list []
|
||||
var user *beans.User
|
||||
user, err = utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
return nil, err
|
||||
}
|
||||
req.Creator = user.UserName
|
||||
|
||||
models, total, err = dao.Model.GetByCreatorAndPlatform(ctx, req)
|
||||
models, total, err := dao.Model.GetByCreatorAndPlatform(ctx, req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 处理列表中每条记录的 JSONB 字段
|
||||
for _, m := range models {
|
||||
m.Form = ParseJSONField(m.Form)
|
||||
m.RequestMapping = ParseJSONField(m.RequestMapping)
|
||||
m.ResponseMapping = ParseJSONField(m.ResponseMapping)
|
||||
m.ResponseBody = ParseJSONField(m.ResponseBody)
|
||||
m.Form = util.ParseJSONField(m.Form)
|
||||
m.RequestMapping = util.ParseJSONField(m.RequestMapping)
|
||||
m.ResponseMapping = util.ParseJSONField(m.ResponseMapping)
|
||||
m.ResponseBody = util.ParseJSONField(m.ResponseBody)
|
||||
}
|
||||
return models, total, nil
|
||||
return &dto.ListModelRes{
|
||||
List: models,
|
||||
Total: total,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetModelTypesFromConfig 从配置文件读取模型类型
|
||||
func GetModelTypesFromConfig(ctx context.Context) map[int]string {
|
||||
typeMap := make(map[int]string)
|
||||
|
||||
// 读取配置
|
||||
configMap := g.Cfg().MustGet(ctx, "modelType.types").Map()
|
||||
for k, v := range configMap {
|
||||
typeID := gconv.Int(k)
|
||||
typeName := gconv.String(v)
|
||||
if typeID > 0 && typeName != "" {
|
||||
typeMap[typeID] = typeName
|
||||
}
|
||||
func GetModelTypesFromConfig() (res *dto.TypeItem, err error) {
|
||||
// 返回副本,避免外部修改
|
||||
types := make(map[int]string, len(public.ModelTypeName))
|
||||
for k, v := range public.ModelTypeName {
|
||||
types[k] = v
|
||||
}
|
||||
// 如果配置为空,使用默认值
|
||||
if len(typeMap) == 0 {
|
||||
typeMap = map[int]string{
|
||||
1: "推理模型",
|
||||
2: "图片模型",
|
||||
3: "音频模型",
|
||||
4: "向量化模型",
|
||||
5: "全模态模型",
|
||||
}
|
||||
}
|
||||
return typeMap
|
||||
return &dto.TypeItem{
|
||||
Type: types,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error {
|
||||
// 校验新会话模型是否存在
|
||||
newModel, err := dao.Model.Get(ctx, req.Id)
|
||||
newModel, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if newModel == nil {
|
||||
return errors.New("新会话模型不存在")
|
||||
}
|
||||
|
||||
// 获取当前用户会话模型
|
||||
currentModel, err := dao.Model.GetByIsChatModel(ctx)
|
||||
currentModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
IsChatModel: new(1),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -219,8 +311,8 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM
|
||||
|
||||
// 如果点击的就是当前会话模型(已经是1),取消它(设为0)
|
||||
if currentModel.Id != req.Id {
|
||||
_, err = dao.Model.Update(ctx, &dto.UpdateModelReq{
|
||||
ID: currentModel.Id,
|
||||
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id},
|
||||
IsChatModel: gconv.PtrInt(0),
|
||||
})
|
||||
if err != nil {
|
||||
@@ -230,8 +322,8 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM
|
||||
}
|
||||
|
||||
// 设置当前为会话模型(设为1)
|
||||
_, err = dao.Model.Update(ctx, &dto.UpdateModelReq{
|
||||
ID: req.Id,
|
||||
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id},
|
||||
IsChatModel: gconv.PtrInt(1),
|
||||
})
|
||||
return err
|
||||
@@ -239,17 +331,21 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *modelService) GetIsChatModel(ctx context.Context) (*entity.AsynchModel, error) {
|
||||
model, err := dao.Model.GetByIsChatModel(ctx)
|
||||
func (s *modelService) GetIsChatModel(ctx context.Context) (*dto.GetIsChatModelRes, error) {
|
||||
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
IsChatModel: new(1),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if model == nil {
|
||||
return nil, nil
|
||||
}
|
||||
model.Form = ParseJSONField(model.Form)
|
||||
model.RequestMapping = ParseJSONField(model.RequestMapping)
|
||||
model.ResponseMapping = ParseJSONField(model.ResponseMapping)
|
||||
model.ResponseBody = ParseJSONField(model.ResponseBody)
|
||||
return model, nil
|
||||
model.Form = util.ParseJSONField(model.Form)
|
||||
model.RequestMapping = util.ParseJSONField(model.RequestMapping)
|
||||
model.ResponseMapping = util.ParseJSONField(model.ResponseMapping)
|
||||
model.ResponseBody = util.ParseJSONField(model.ResponseBody)
|
||||
return &dto.GetIsChatModelRes{
|
||||
Model: model,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
package service
|
||||
|
||||
import "github.com/gogf/gf/v2/util/gconv"
|
||||
|
||||
// parseStoredPayload 解析入库的 request_payload,拆出模型调用 payload 与透传 headers
|
||||
// 入库格式:{"payload": <any>, "headers": {"Authorization": "...", "X-User-Info":"..."}}
|
||||
func parseStoredPayload(v any) (payload any, headers map[string]string) {
|
||||
if v == nil {
|
||||
return nil, nil
|
||||
}
|
||||
m := gconv.Map(v)
|
||||
if len(m) == 0 {
|
||||
return v, nil
|
||||
}
|
||||
if h, ok := m["headers"]; ok {
|
||||
headers = gconv.MapStrStr(h)
|
||||
}
|
||||
if p, ok := m["payload"]; ok {
|
||||
payload = p
|
||||
} else {
|
||||
payload = v
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"model-gateway/model/entity"
|
||||
)
|
||||
|
||||
// StorageService 结果存储(OSS/MinIO)抽象
|
||||
type StorageService interface {
|
||||
UploadByTask(ctx context.Context, t *entity.AsynchTask, data []byte, fileExt string, contentType string) (ossURL string, err error)
|
||||
}
|
||||
|
||||
// Storage 默认存储实现(优先对接你们的 oss 文件服务;必要时也可以切到 MinIO)
|
||||
var Storage StorageService = &ossStorage{}
|
||||
|
||||
var ErrStorageNotConfigured = errors.New("存储未配置")
|
||||
@@ -1,81 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"mime/multipart"
|
||||
"time"
|
||||
|
||||
"model-gateway/model/entity"
|
||||
|
||||
commonHttp "gitea.com/red-future/common/http"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
"github.com/gogf/gf/v2/util/guid"
|
||||
)
|
||||
|
||||
// 对接你们的 oss 文件服务:POST oss/file/uploadFile (multipart/form-data)
|
||||
type ossStorage struct{}
|
||||
|
||||
type uploadFileResponse struct {
|
||||
FileURL string `json:"fileURL"` // 文件 URL
|
||||
FileSize int `json:"fileSize"` // 文件大小(字节)
|
||||
FileName string `json:"fileName"` // 文件名
|
||||
FileFormat string `json:"fileFormat"` // 文件格式
|
||||
FileAddressPrefix string `json:"fileAddressPrefix"` // 文件地址前缀
|
||||
}
|
||||
|
||||
func (s *ossStorage) UploadByTask(ctx context.Context, _ *entity.AsynchTask, data []byte, fileExt string, _ string) (ossURL string, err error) {
|
||||
// multipart
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
|
||||
ext := fileExt
|
||||
if ext == "" {
|
||||
ext = ".bin"
|
||||
}
|
||||
if ext[0] != '.' {
|
||||
ext = "." + ext
|
||||
}
|
||||
|
||||
filename := fmt.Sprintf("asynch_%d_%s%s", time.Now().Unix(), guid.S(), ext)
|
||||
part, err := writer.CreateFormFile("file", filename)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if _, err := part.Write(data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
contentType := writer.FormDataContentType()
|
||||
if err := writer.Close(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
headers := forwardHeaders(ctx)
|
||||
headers["Content-Type"] = contentType
|
||||
|
||||
fullURL := "oss/file/uploadFile"
|
||||
g.Log().Infof(ctx, "[OSS] upload start url=%s filename=%s size=%d", fullURL, filename, len(data))
|
||||
|
||||
var resp uploadFileResponse
|
||||
if err := commonHttp.Post(ctx, fullURL, headers, &resp, body.Bytes()); err != nil {
|
||||
return "", err
|
||||
}
|
||||
g.Log().Infof(ctx, "[OSS] upload success url=%s size=%d format=%s", resp.FileURL, resp.FileSize, resp.FileFormat)
|
||||
return resp.FileURL, nil
|
||||
}
|
||||
|
||||
// setTaskHeadersToCtx 把任务入库时保存的 header 信息注入 ctx,给 worker 调 OSS 用
|
||||
func setTaskHeadersToCtx(ctx context.Context, headers map[string]string) context.Context {
|
||||
if headers == nil {
|
||||
return ctx
|
||||
}
|
||||
if v := gconv.String(headers["Authorization"]); v != "" {
|
||||
ctx = context.WithValue(ctx, "token", v)
|
||||
}
|
||||
if v := gconv.String(headers["X-User-Info"]); v != "" {
|
||||
ctx = context.WithValue(ctx, "xUserInfo", v)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
@@ -3,7 +3,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"model-gateway/common/util"
|
||||
"time"
|
||||
|
||||
"model-gateway/dao"
|
||||
@@ -21,13 +21,13 @@ var Task = &taskService{}
|
||||
type taskService struct{}
|
||||
|
||||
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
|
||||
fmt.Printf("打印请求:%+v", req)
|
||||
startAt := time.Now()
|
||||
// 固化 token/user 等信息
|
||||
ctx = asyncCtx(ctx)
|
||||
|
||||
ctx = util.AsyncCtx(ctx)
|
||||
// 1) 检查模型配置
|
||||
m, err := dao.Model.GetByModelName(ctx, req.ModelName)
|
||||
m, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
ModelName: req.ModelName,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -51,7 +51,7 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
|
||||
// 将调用模型的 payload 与透传头信息一起存入 request_payload,供后台 worker 使用
|
||||
storedPayload := map[string]any{
|
||||
"payload": req.RequestPayload,
|
||||
"headers": forwardHeaders(ctx),
|
||||
"headers": util.ForwardHeaders(ctx),
|
||||
}
|
||||
|
||||
t := &entity.AsynchTask{
|
||||
@@ -127,7 +127,9 @@ func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string,
|
||||
defer ticker.Stop()
|
||||
|
||||
tryRun := func() bool {
|
||||
t, err := dao.Task.GetByTaskID(ctx, taskID)
|
||||
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
|
||||
TaskID: taskID,
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[task-auto-run][stop] taskId=%s reason=query_failed err=%v", taskID, err)
|
||||
return true
|
||||
@@ -138,7 +140,7 @@ func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string,
|
||||
}
|
||||
switch t.State {
|
||||
case 0:
|
||||
if err := AsyncWorker.RunByTaskID(ctx, taskID, epicycleId); err != nil {
|
||||
if err = AsyncWorker.RunByTaskID(ctx, taskID, epicycleId); err != nil {
|
||||
g.Log().Warningf(ctx, "[task-auto-run][retry] taskId=%s state=0 err=%v", taskID, err)
|
||||
} else {
|
||||
g.Log().Infof(ctx, "[task-auto-run][triggered] taskId=%s state=0", taskID)
|
||||
@@ -175,7 +177,9 @@ func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string,
|
||||
}
|
||||
|
||||
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
|
||||
t, err := dao.Task.GetByTaskID(ctx, taskID)
|
||||
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
|
||||
TaskID: taskID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -209,7 +213,9 @@ func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (r
|
||||
continue
|
||||
}
|
||||
// 按模型配置决定保留时间
|
||||
m, err := dao.Model.GetByModelName(ctx, t.ModelName)
|
||||
m, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
ModelName: t.ModelName,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// saveTmpResult 将模型输出写入临时文件,用于 OSS 上传失败后的“仅重试 OSS”。
|
||||
func saveTmpResult(taskID string, data []byte, ext string) (string, error) {
|
||||
dir := filepath.Join(os.TempDir(), "model-asynch")
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if ext == "" {
|
||||
ext = ".bin"
|
||||
}
|
||||
if ext[0] != '.' {
|
||||
ext = "." + ext
|
||||
}
|
||||
path := filepath.Join(dir, fmt.Sprintf("%s%s", taskID, ext))
|
||||
if err := os.WriteFile(path, data, 0o644); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
|
||||
func loadTmpResult(path string) ([]byte, error) {
|
||||
return os.ReadFile(path)
|
||||
}
|
||||
|
||||
func deleteTmpResult(path string) {
|
||||
if path == "" {
|
||||
return
|
||||
}
|
||||
_ = os.Remove(path)
|
||||
}
|
||||
|
||||
113
service/utils.go
113
service/utils.go
@@ -1,113 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/gogf/gf/v2/container/gvar"
|
||||
)
|
||||
|
||||
func normalizeFormValue(v any) any {
|
||||
// 目标:对外永远返回 JSON 数组/对象,而不是字符串。
|
||||
if v == nil {
|
||||
return []any{}
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case string:
|
||||
s := strings.TrimSpace(t)
|
||||
if s == "" {
|
||||
return []any{}
|
||||
}
|
||||
return normalizeFormValueFromJSONString(s)
|
||||
case []byte:
|
||||
if len(t) == 0 {
|
||||
return []any{}
|
||||
}
|
||||
return normalizeFormValueFromJSONBytes(t)
|
||||
case *gvar.Var:
|
||||
// goframe 常见的 DB 返回类型
|
||||
if t == nil {
|
||||
return []any{}
|
||||
}
|
||||
b := t.Bytes()
|
||||
if len(b) > 0 {
|
||||
return normalizeFormValueFromJSONBytes(b)
|
||||
}
|
||||
s := strings.TrimSpace(t.String())
|
||||
if s == "" {
|
||||
return []any{}
|
||||
}
|
||||
return normalizeFormValueFromJSONString(s)
|
||||
default:
|
||||
// 尝试兼容其他“像 JSON 的值类型”(例如实现了 Bytes/String 的包装类型)
|
||||
if vb, ok := v.(interface{ Bytes() []byte }); ok {
|
||||
if b := vb.Bytes(); len(b) > 0 {
|
||||
return normalizeFormValueFromJSONBytes(b)
|
||||
}
|
||||
}
|
||||
if vs, ok := v.(interface{ String() string }); ok {
|
||||
if s := strings.TrimSpace(vs.String()); s != "" {
|
||||
return normalizeFormValueFromJSONString(s)
|
||||
}
|
||||
}
|
||||
// 已经是 []any / map[string]any 等结构
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
// 兼容“JSONB 里存了 JSON 字符串”的历史数据:
|
||||
// 例如 form_json = '"[]"' 或 '"[{...}]"'(外层是字符串,内层才是数组/对象)
|
||||
func normalizeFormValueFromJSONString(s string) any {
|
||||
var out any
|
||||
if err := json.Unmarshal([]byte(s), &out); err != nil || out == nil {
|
||||
return []any{}
|
||||
}
|
||||
// 如果解出来还是 string,且看起来是 JSON,再解一层
|
||||
if inner, ok := out.(string); ok {
|
||||
inner = strings.TrimSpace(inner)
|
||||
if inner == "" {
|
||||
return []any{}
|
||||
}
|
||||
if strings.HasPrefix(inner, "[") || strings.HasPrefix(inner, "{") {
|
||||
var out2 any
|
||||
if err := json.Unmarshal([]byte(inner), &out2); err == nil && out2 != nil {
|
||||
return out2
|
||||
}
|
||||
}
|
||||
return []any{}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeFormValueFromJSONBytes(b []byte) any {
|
||||
var out any
|
||||
if err := json.Unmarshal(b, &out); err != nil || out == nil {
|
||||
return []any{}
|
||||
}
|
||||
// bytes 解出来也可能是 string(同上)
|
||||
if inner, ok := out.(string); ok {
|
||||
return normalizeFormValueFromJSONString(inner)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func ParseJSONField(field any) any {
|
||||
var v *gvar.Var
|
||||
switch val := field.(type) {
|
||||
case *gvar.Var:
|
||||
v = val
|
||||
default:
|
||||
return field
|
||||
}
|
||||
|
||||
if v == nil || v.IsNil() || v.IsEmpty() {
|
||||
return nil
|
||||
}
|
||||
|
||||
str := v.String()
|
||||
var result any
|
||||
if json.Unmarshal([]byte(str), &result) == nil {
|
||||
return result
|
||||
}
|
||||
return str
|
||||
}
|
||||
@@ -2,7 +2,13 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"model-gateway/common/util"
|
||||
"model-gateway/model/dto"
|
||||
"model-gateway/service/gateway"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
@@ -23,24 +29,23 @@ type asyncWorker struct {
|
||||
// RunOnce 由上层定时任务触发:一次性抢占并处理一批任务
|
||||
// - batchSize: 本次抢占数量
|
||||
// - goroutines: 本次并发数(协程池大小)
|
||||
func (w *asyncWorker) RunOnce(ctx context.Context, batchSize, goroutines int) (claimed int, err error) {
|
||||
if batchSize <= 0 {
|
||||
batchSize = 10
|
||||
func (w *asyncWorker) RunOnce(ctx context.Context, req *dto.RunWorkReq) (res *dto.RunWorkRes, err error) {
|
||||
if req.BatchSize <= 0 {
|
||||
req.BatchSize = 10
|
||||
}
|
||||
if goroutines <= 0 {
|
||||
goroutines = 1
|
||||
if req.Goroutines <= 0 {
|
||||
req.Goroutines = 1
|
||||
}
|
||||
tasks, err := dao.Task.ClaimPendingGlobal(ctx, batchSize)
|
||||
tasks, err := dao.Task.ClaimPendingGlobal(ctx, req.BatchSize)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return nil, err
|
||||
}
|
||||
if len(tasks) == 0 {
|
||||
return 0, nil
|
||||
return nil, errors.New("no task to run")
|
||||
}
|
||||
pool := grpool.New(goroutines)
|
||||
pool := grpool.New(req.Goroutines)
|
||||
defer pool.Close()
|
||||
|
||||
claimed = len(tasks)
|
||||
claimed := len(tasks)
|
||||
done := make(chan struct{}, claimed)
|
||||
for _, t := range tasks {
|
||||
task := t
|
||||
@@ -58,7 +63,9 @@ func (w *asyncWorker) RunOnce(ctx context.Context, batchSize, goroutines int) (c
|
||||
for i := 0; i < claimed; i++ {
|
||||
<-done
|
||||
}
|
||||
return claimed, nil
|
||||
return &dto.RunWorkRes{
|
||||
Claimed: claimed,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RunByTaskID 创建任务后立即异步尝试执行当前任务:
|
||||
@@ -78,9 +85,9 @@ func (w *asyncWorker) RunByTaskID(ctx context.Context, taskID string, epicycleId
|
||||
|
||||
func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicycleId int64) {
|
||||
// 从任务入库的 request_payload 里恢复 payload + headers
|
||||
payload, headers := parseStoredPayload(t.RequestPayload)
|
||||
payload, headers := util.ParseStoredPayload(t.RequestPayload)
|
||||
if len(headers) > 0 {
|
||||
ctx = setTaskHeadersToCtx(ctx, headers)
|
||||
ctx = util.SetTaskHeadersToCtx(ctx, headers)
|
||||
}
|
||||
|
||||
// 1) 拉取模型配置
|
||||
@@ -91,7 +98,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
||||
// ============ 失败回调 ============
|
||||
t.State = 3
|
||||
t.ErrorMsg = err.Error()
|
||||
go triggerCallback(context.WithoutCancel(ctx), t)
|
||||
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||||
// ================================
|
||||
return
|
||||
}
|
||||
@@ -102,7 +109,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
||||
// ============ 失败回调 ============
|
||||
t.State = 3
|
||||
t.ErrorMsg = errMsg
|
||||
go triggerCallback(context.WithoutCancel(ctx), t)
|
||||
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||||
// ================================
|
||||
return
|
||||
}
|
||||
@@ -118,7 +125,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
||||
// ============ 失败回调 ============
|
||||
t.State = 3
|
||||
t.ErrorMsg = err.Error()
|
||||
go triggerCallback(context.WithoutCancel(ctx), t)
|
||||
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||||
// ================================
|
||||
return
|
||||
}
|
||||
@@ -147,9 +154,9 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
||||
|
||||
// phase=1 表示模型已成功但 OSS 上传失败:优先从临时文件加载
|
||||
if t.Phase == 1 && strings.TrimSpace(t.TmpFile) != "" {
|
||||
data, err = loadTmpResult(t.TmpFile)
|
||||
data, err = os.ReadFile(t.TmpFile)
|
||||
if err == nil && len(data) > 0 {
|
||||
contentType, ext = DetectFileType(data)
|
||||
contentType, ext = util.DetectFileType(data)
|
||||
} else {
|
||||
data = nil
|
||||
}
|
||||
@@ -165,11 +172,11 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
||||
// ============ 失败回调 ============
|
||||
t.State = 3
|
||||
t.ErrorMsg = err.Error()
|
||||
go triggerCallback(context.WithoutCancel(ctx), t)
|
||||
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||||
// ================================
|
||||
return
|
||||
}
|
||||
contentType, ext = DetectFileType(data)
|
||||
contentType, ext = util.DetectFileType(data)
|
||||
if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") {
|
||||
textResult = string(data)
|
||||
}
|
||||
@@ -182,7 +189,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
||||
}
|
||||
|
||||
// 4) 存储 OSS
|
||||
ossURL, err := Storage.UploadByTask(ctx, t, data, ext, contentType)
|
||||
ossURL, err := gateway.UploadByTask(ctx, t, data, ext, contentType)
|
||||
if err != nil {
|
||||
// OSS 阶段失败:保留临时文件,下一轮仅重试 OSS
|
||||
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, t.Id, err.Error())
|
||||
@@ -198,7 +205,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
||||
if fileType == "" {
|
||||
fileType = contentType
|
||||
}
|
||||
if err := dao.Task.UpdateSuccessGlobal(
|
||||
if err = dao.Task.UpdateSuccessGlobal(
|
||||
ctx,
|
||||
t.Id,
|
||||
ossURL,
|
||||
@@ -206,7 +213,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
||||
textResult,
|
||||
int64(len(data)),
|
||||
nil,
|
||||
GetExpendTokens(m.TokenMapping, textResult),
|
||||
GetExpendTokens(m.ResponseTokenField, textResult),
|
||||
); err != nil {
|
||||
g.Log().Errorf(ctx, "[worker] update success failed: %v", err)
|
||||
return
|
||||
@@ -221,14 +228,33 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
||||
t.FileType = fileType
|
||||
t.TextResult = textResult
|
||||
g.Log().Infof(ctx, "[CALLBACK][DISPATCH] taskId=%s bizName=%s callbackUrl=%s", t.TaskID, t.BizName, t.CallbackURL)
|
||||
go triggerCallback(context.WithoutCancel(ctx), t)
|
||||
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||||
// ============ 如果有 epicycleId,也触发业务回调 ============
|
||||
if epicycleId != 0 {
|
||||
go triggerPromptsCallback(context.WithoutCancel(ctx), t, epicycleId)
|
||||
go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), t, epicycleId)
|
||||
}
|
||||
|
||||
// 成功后清理临时文件
|
||||
deleteTmpResult(t.TmpFile)
|
||||
_ = os.Remove(t.TmpFile)
|
||||
}
|
||||
|
||||
// saveTmpResult 将模型输出写入临时文件,用于 OSS 上传失败后的“仅重试 OSS”。
|
||||
func saveTmpResult(taskID string, data []byte, ext string) (string, error) {
|
||||
dir := filepath.Join(os.TempDir(), "model-asynch")
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if ext == "" {
|
||||
ext = ".bin"
|
||||
}
|
||||
if ext[0] != '.' {
|
||||
ext = "." + ext
|
||||
}
|
||||
path := filepath.Join(dir, fmt.Sprintf("%s%s", taskID, ext))
|
||||
if err := os.WriteFile(path, data, 0o644); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
|
||||
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
|
||||
@@ -240,7 +266,6 @@ func GetExpendTokens(tokenMapping string, textResult string) int {
|
||||
value := gjson.Get(textResult, tokenMapping)
|
||||
if value.Exists() {
|
||||
return int(value.Int())
|
||||
} else {
|
||||
return len(textResult)
|
||||
}
|
||||
return len(textResult)
|
||||
}
|
||||
|
||||
Binary file not shown.
Binary file not shown.
28
update.sql
28
update.sql
@@ -40,9 +40,18 @@ CREATE TABLE IF NOT EXISTS asynch_models (
|
||||
retry_queue_max_seconds INT NOT NULL DEFAULT 600, -- 失败重试最大排队时间(秒 0=插队到队首;>0=排队超过该时间后插队,否则仍到队尾)
|
||||
auto_clean_seconds INT NOT NULL DEFAULT 86400, -- 已下载(state=4 后的保留时间(秒),到期清理)
|
||||
remark TEXT DEFAULT '' -- 备注
|
||||
token_mapping VARCHAR(128) NOT NULL DEFAULT ''; -- token 映射
|
||||
);
|
||||
|
||||
response_token_field VARCHAR(128) NOT NULL DEFAULT ''; -- 响应中消耗token的字段映射
|
||||
operator_name VARCHAR(64) NOT NULL DEFAULT '', -- 运营商名称
|
||||
token_config JSONB NOT NULL DEFAULT '{
|
||||
"zh_ratio": 1.0,
|
||||
"en_ratio": 1.3,
|
||||
"space_ratio": 0.1,
|
||||
"punctuation_ratio": 0.1,
|
||||
"max_window_size": 8192,
|
||||
"reserve_ratio": 0.2,
|
||||
"min_reserve": 512,
|
||||
}'::jsonb -- Token配置
|
||||
);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS uk_asynch_models_tenant_creator_chat ON asynch_models(tenant_id, creator) WHERE is_chat_model = 1 AND deleted_at IS NULL;
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS uk_asynch_models_tenant_model_name ON asynch_models(tenant_id, creator, model_name);
|
||||
CREATE INDEX IF NOT EXISTS idx_asynch_models_tenant_id ON asynch_models(tenant_id);
|
||||
@@ -83,8 +92,17 @@ COMMENT ON COLUMN asynch_models.retry_times IS '失败重试次数';
|
||||
COMMENT ON COLUMN asynch_models.retry_queue_max_seconds IS '失败重试最大排队时间(秒 0=插队到队首;>0=排队超过该时间后插队,否则仍到队尾)';
|
||||
COMMENT ON COLUMN asynch_models.auto_clean_seconds IS '已下载(state=4 后的保留时间(秒),到期清理)';
|
||||
COMMENT ON COLUMN asynch_models.remark IS '备注';
|
||||
COMMENT ON COLUMN asynch_models.token_mapping IS 'token映射';
|
||||
|
||||
COMMENT ON COLUMN asynch_models.response_token_field IS '响应中消耗token的字段映射';
|
||||
COMMENT ON COLUMN asynch_models.operator_name IS '运营商名称';
|
||||
COMMENT ON COLUMN asynch_models.token_config IS '{
|
||||
"zh_ratio": 1.0, // 中文字符→token系数
|
||||
"en_ratio": 1.3, // 英文单词→token系数
|
||||
"space_ratio": 0.1, // 空格系数
|
||||
"punctuation_ratio": 0.1, // 标点系数
|
||||
"max_window_size": 8192, // 模型最大窗口
|
||||
"reserve_ratio": 0.2, // 预留回复空间比例
|
||||
"min_reserve": 512, // 最少预留token数
|
||||
}';
|
||||
|
||||
|
||||
-- =========================
|
||||
|
||||
Reference in New Issue
Block a user