refactor(model): 重构模型实体和数据访问层
This commit is contained in:
69
common/util/files.go
Normal file
69
common/util/files.go
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DetectFileType 根据返回的二进制内容推断 contentType + 扩展名(尽量稳定)
|
||||||
|
func DetectFileType(data []byte) (contentType string, ext string) {
|
||||||
|
if len(data) == 0 {
|
||||||
|
return "application/octet-stream", ""
|
||||||
|
}
|
||||||
|
ct := http.DetectContentType(data)
|
||||||
|
// gateway.DetectContentType 可能带 charset 等参数:text/plain; charset=utf-8
|
||||||
|
if idx := strings.Index(ct, ";"); idx > 0 {
|
||||||
|
ct = strings.TrimSpace(ct[:idx])
|
||||||
|
}
|
||||||
|
switch ct {
|
||||||
|
case "audio/mpeg":
|
||||||
|
return ct, ".mp3"
|
||||||
|
case "audio/wave", "audio/wav", "audio/x-wav":
|
||||||
|
return ct, ".wav"
|
||||||
|
case "video/mp4":
|
||||||
|
return ct, ".mp4"
|
||||||
|
case "image/png":
|
||||||
|
return ct, ".png"
|
||||||
|
case "image/jpeg":
|
||||||
|
return ct, ".jpg"
|
||||||
|
case "application/pdf":
|
||||||
|
return ct, ".pdf"
|
||||||
|
case "text/plain":
|
||||||
|
return ct, ".txt"
|
||||||
|
case "application/json":
|
||||||
|
return ct, ".json"
|
||||||
|
default:
|
||||||
|
// 兜底:尝试从 ct 截取 subtype 作为后缀(例如 application/json)
|
||||||
|
if parts := strings.Split(ct, "/"); len(parts) == 2 {
|
||||||
|
sub := parts[1]
|
||||||
|
// 避免出现 "plain; charset=utf-8" 之类的后缀
|
||||||
|
if idx := strings.Index(sub, ";"); idx > 0 {
|
||||||
|
sub = strings.TrimSpace(sub[:idx])
|
||||||
|
}
|
||||||
|
return ct, "." + sub
|
||||||
|
}
|
||||||
|
return ct, ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveTmpResult 将模型输出写入临时文件,用于 OSS 上传失败后的“仅重试 OSS”。
|
||||||
|
func SaveTmpResult(taskID string, data []byte, ext string) (string, error) {
|
||||||
|
dir := filepath.Join(os.TempDir(), "model-asynch")
|
||||||
|
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if ext == "" {
|
||||||
|
ext = ".bin"
|
||||||
|
}
|
||||||
|
if ext[0] != '.' {
|
||||||
|
ext = "." + ext
|
||||||
|
}
|
||||||
|
path := filepath.Join(dir, fmt.Sprintf("%s%s", taskID, ext))
|
||||||
|
if err := os.WriteFile(path, data, 0o644); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return path, nil
|
||||||
|
}
|
||||||
100
common/util/headers.go
Normal file
100
common/util/headers.go
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"gitea.com/red-future/common/utils"
|
||||||
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AsyncCtx 固化异步上下文中的 token 和用户信息,避免请求结束后丢失
|
||||||
|
func AsyncCtx(ctx context.Context) context.Context {
|
||||||
|
asyncCtx := context.WithoutCancel(ctx)
|
||||||
|
|
||||||
|
if r := g.RequestFromCtx(ctx); r != nil {
|
||||||
|
if token := r.Header.Get("Authorization"); token != "" {
|
||||||
|
asyncCtx = context.WithValue(asyncCtx, "token", token)
|
||||||
|
}
|
||||||
|
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
|
||||||
|
asyncCtx = context.WithValue(asyncCtx, "xUserInfo", userInfo)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if user, err := utils.GetUserInfo(ctx); err == nil && user != nil {
|
||||||
|
asyncCtx = context.WithValue(asyncCtx, "user", user)
|
||||||
|
}
|
||||||
|
|
||||||
|
return asyncCtx
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForwardHeaders 透传调用链路的头信息,优先使用 ctx 中的固化值
|
||||||
|
func ForwardHeaders(ctx context.Context) map[string]string {
|
||||||
|
headers := make(map[string]string)
|
||||||
|
SetHeaderFromContext(headers, ctx, "Authorization", "token")
|
||||||
|
SetHeaderFromContext(headers, ctx, "X-User-Info", "xUserInfo")
|
||||||
|
FallbackToRequestHeaders(headers, ctx)
|
||||||
|
return headers
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetHeaderFromContext 从上下文中设置 header
|
||||||
|
func SetHeaderFromContext(headers map[string]string, ctx context.Context, headerKey, ctxKey string) {
|
||||||
|
if value, ok := ctx.Value(ctxKey).(string); ok && value != "" {
|
||||||
|
headers[headerKey] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackToRequestHeaders 从请求头中获取作为兜底
|
||||||
|
func FallbackToRequestHeaders(headers map[string]string, ctx context.Context) {
|
||||||
|
r := g.RequestFromCtx(ctx)
|
||||||
|
if r == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if headers["Authorization"] == "" {
|
||||||
|
if token := r.Header.Get("Authorization"); token != "" {
|
||||||
|
headers["Authorization"] = token
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if headers["X-User-Info"] == "" {
|
||||||
|
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
|
||||||
|
headers["X-User-Info"] = userInfo
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTaskHeadersToCtx 把任务入库时保存的 header 信息注入 ctx,给 worker 调 OSS 用
|
||||||
|
func SetTaskHeadersToCtx(ctx context.Context, headers map[string]string) context.Context {
|
||||||
|
if headers == nil {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
if v := gconv.String(headers["Authorization"]); v != "" {
|
||||||
|
ctx = context.WithValue(ctx, "token", v)
|
||||||
|
}
|
||||||
|
if v := gconv.String(headers["X-User-Info"]); v != "" {
|
||||||
|
ctx = context.WithValue(ctx, "xUserInfo", v)
|
||||||
|
}
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseStoredPayload 解析入库的 request_payload,拆出模型调用 payload 与透传 headers
|
||||||
|
// 入库格式:{"payload": <any>, "headers": {"Authorization": "...", "X-User-Info":"..."}}
|
||||||
|
func ParseStoredPayload(v any) (payload any, headers map[string]string) {
|
||||||
|
if v == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
m := gconv.Map(v)
|
||||||
|
if len(m) == 0 {
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
if h, ok := m["headers"]; ok {
|
||||||
|
headers = gconv.MapStrStr(h)
|
||||||
|
}
|
||||||
|
if p, ok := m["payload"]; ok {
|
||||||
|
payload = p
|
||||||
|
} else {
|
||||||
|
payload = v
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
28
common/util/json.go
Normal file
28
common/util/json.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/gogf/gf/v2/container/gvar"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ParseJSONField(field any) any {
|
||||||
|
var v *gvar.Var
|
||||||
|
switch val := field.(type) {
|
||||||
|
case *gvar.Var:
|
||||||
|
v = val
|
||||||
|
default:
|
||||||
|
return field
|
||||||
|
}
|
||||||
|
|
||||||
|
if v == nil || v.IsNil() || v.IsEmpty() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
str := v.String()
|
||||||
|
var result any
|
||||||
|
if json.Unmarshal([]byte(str), &result) == nil {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
return str
|
||||||
|
}
|
||||||
29
config.yml
29
config.yml
@@ -26,6 +26,27 @@ database:
|
|||||||
updatedAt: "updated_at" # (可选)自动更新时间字段名称
|
updatedAt: "updated_at" # (可选)自动更新时间字段名称
|
||||||
deletedAt: "deleted_at" # (可选)软删除时间字段名称
|
deletedAt: "deleted_at" # (可选)软删除时间字段名称
|
||||||
timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性,为true时CreatedAt/UpdatedAt/DeletedAt都将失效
|
timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性,为true时CreatedAt/UpdatedAt/DeletedAt都将失效
|
||||||
|
model_gateway:
|
||||||
|
- type: "pgsql"
|
||||||
|
host: "116.204.74.41"
|
||||||
|
port: "15432"
|
||||||
|
user: "postgres"
|
||||||
|
pass: "Bjang09@686^*^"
|
||||||
|
name: "model-gateway"
|
||||||
|
prefix: ""
|
||||||
|
role: "master"
|
||||||
|
debug: true
|
||||||
|
dryRun: false
|
||||||
|
charset: "utf8"
|
||||||
|
timezone: "Asia/Shanghai"
|
||||||
|
maxIdle: 5
|
||||||
|
maxOpen: 20
|
||||||
|
maxLifetime: "30s"
|
||||||
|
maxIdleConnTime: "30s"
|
||||||
|
createdAt: "created_at"
|
||||||
|
updatedAt: "updated_at"
|
||||||
|
deletedAt: "deleted_at"
|
||||||
|
timeMaintainDisabled: false
|
||||||
|
|
||||||
redis:
|
redis:
|
||||||
default:
|
default:
|
||||||
@@ -48,11 +69,3 @@ asynch:
|
|||||||
cleaner:
|
cleaner:
|
||||||
enabled: false
|
enabled: false
|
||||||
intervalSeconds: 30
|
intervalSeconds: 30
|
||||||
|
|
||||||
modelType:
|
|
||||||
types:
|
|
||||||
1: "推理模型"
|
|
||||||
2: "图片模型"
|
|
||||||
3: "音频模型"
|
|
||||||
4: "向量化模型"
|
|
||||||
5: "全模态模型"
|
|
||||||
|
|||||||
19
consts/public/public.go
Normal file
19
consts/public/public.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package public
|
||||||
|
|
||||||
|
// ModelType 模型类型常量
|
||||||
|
const (
|
||||||
|
ModelTypeInference = 1 // 推理模型
|
||||||
|
ModelTypeImage = 2 // 图片模型
|
||||||
|
ModelTypeAudio = 3 // 音频模型
|
||||||
|
ModelTypeVector = 4 // 向量化模型
|
||||||
|
ModelTypeOmni = 5 // 全模态模型
|
||||||
|
)
|
||||||
|
|
||||||
|
// ModelTypeName 模型类型名称映射
|
||||||
|
var ModelTypeName = map[int]string{
|
||||||
|
ModelTypeInference: "推理模型",
|
||||||
|
ModelTypeImage: "图片模型",
|
||||||
|
ModelTypeAudio: "音频模型",
|
||||||
|
ModelTypeVector: "向量化模型",
|
||||||
|
ModelTypeOmni: "全模态模型",
|
||||||
|
}
|
||||||
@@ -1,5 +1,9 @@
|
|||||||
package public
|
package public
|
||||||
|
|
||||||
|
const (
|
||||||
|
DbNameModelGateway = "model_gateway" //数据库名称
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
TableNameModel = "asynch_models" // 模型表
|
TableNameModel = "asynch_models" // 模型表
|
||||||
TableNameTask = "asynch_task" // 任务表
|
TableNameTask = "asynch_task" // 任务表
|
||||||
|
|||||||
@@ -4,10 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
|
|
||||||
"model-gateway/model/dto"
|
"model-gateway/model/dto"
|
||||||
"model-gateway/model/entity"
|
|
||||||
"model-gateway/service"
|
"model-gateway/service"
|
||||||
|
|
||||||
"gitea.com/red-future/common/beans"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type model struct{}
|
type model struct{}
|
||||||
@@ -21,67 +18,44 @@ func (c *model) CreateModel(ctx context.Context, req *dto.CreateModelReq) (res *
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdateModel 更改配置
|
// UpdateModel 更改配置
|
||||||
func (c *model) UpdateModel(ctx context.Context, req *dto.UpdateModelReq) (res *beans.ResponseEmpty, err error) {
|
func (c *model) UpdateModel(ctx context.Context, req *dto.UpdateModelReq) (res *dto.UpdateModelRes, err error) {
|
||||||
err = service.Model.Update(ctx, req)
|
err = service.Model.Update(ctx, req)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteModel 删除配置
|
// DeleteModel 删除配置
|
||||||
func (c *model) DeleteModel(ctx context.Context, req *dto.DeleteModelReq) (res *beans.ResponseEmpty, err error) {
|
func (c *model) DeleteModel(ctx context.Context, req *dto.DeleteModelReq) (res *dto.DeleteModelRes, err error) {
|
||||||
err = service.Model.Delete(ctx, req.ID)
|
err = service.Model.Delete(ctx, req)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetModel 获取配置详情(按 modelName)
|
// GetModel 获取配置详情
|
||||||
func (c *model) GetModel(ctx context.Context, req *dto.GetModelReq) (res *dto.GetModelRes, err error) {
|
func (c *model) GetModel(ctx context.Context, req *dto.GetModelReq) (res *dto.GetModelRes, err error) {
|
||||||
model, err := service.Model.Get(ctx, req.ID)
|
return service.Model.Get(ctx, req)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if model == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
return &dto.GetModelRes{Model: model}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListModel 配置列表
|
// ListModel 配置列表
|
||||||
func (c *model) ListModel(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) {
|
func (c *model) ListModel(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) {
|
||||||
list, total, err := service.Model.List(ctx, req)
|
return service.Model.List(ctx, req)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &dto.ListModelRes{
|
|
||||||
List: list,
|
|
||||||
Total: total,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AutoTune 动态调参(由上层定时任务每小时触发一次)
|
// AutoTune 动态调参(由上层定时任务每小时触发一次)
|
||||||
func (c *model) AutoTune(ctx context.Context, req *dto.AutoTuneReq) (res *dto.AutoTuneRes, err error) {
|
func (c *model) AutoTune(ctx context.Context, req *dto.AutoTuneReq) (res *dto.AutoTuneRes, err error) {
|
||||||
windowSeconds := 3600
|
return service.AutoTune(ctx, req)
|
||||||
if req != nil && req.WindowSeconds > 0 {
|
|
||||||
windowSeconds = req.WindowSeconds
|
|
||||||
}
|
|
||||||
list, err := service.AutoTune(ctx, windowSeconds)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &dto.AutoTuneRes{List: list}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *model) ListType(ctx context.Context, req *dto.ListTypeReq) (res dto.TypeItem, err error) {
|
// ListType 模型类型列表
|
||||||
modelType := service.GetModelTypesFromConfig(ctx)
|
func (c *model) ListType(ctx context.Context, req *dto.ListTypeReq) (res *dto.TypeItem, err error) {
|
||||||
res.Type = modelType
|
return service.GetModelTypesFromConfig()
|
||||||
return res, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateChatModel 更新是否为聊天模型
|
// UpdateChatModel 更新是否为聊天模型
|
||||||
func (c *model) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) (res *beans.ResponseEmpty, err error) {
|
func (c *model) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) (res *dto.UpdateChatModelRes, err error) {
|
||||||
err = service.Model.UpdateChatModel(ctx, req)
|
err = service.Model.UpdateChatModel(ctx, req)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetIsChatModel 获取是否为聊天模型
|
// GetIsChatModel 获取当前会话模型
|
||||||
func (c *model) GetIsChatModel(ctx context.Context, req *dto.GetIsChatModelReq) (res *entity.AsynchModel, err error) {
|
func (c *model) GetIsChatModel(ctx context.Context, req *dto.GetIsChatModelReq) (res *dto.GetIsChatModelRes, err error) {
|
||||||
return service.Model.GetIsChatModel(ctx)
|
return service.Model.GetIsChatModel(ctx)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -34,24 +34,10 @@ func (c *task) ListTask(ctx context.Context, req *dto.ListTaskReq) (res *dto.Lis
|
|||||||
|
|
||||||
// RunWork 手动触发一次 worker(由上层定时任务调用)
|
// RunWork 手动触发一次 worker(由上层定时任务调用)
|
||||||
func (c *task) RunWork(ctx context.Context, req *dto.RunWorkReq) (res *dto.RunWorkRes, err error) {
|
func (c *task) RunWork(ctx context.Context, req *dto.RunWorkReq) (res *dto.RunWorkRes, err error) {
|
||||||
batchSize, goroutines := 10, 1
|
return service.AsyncWorker.RunOnce(ctx, req)
|
||||||
if req != nil {
|
|
||||||
if req.BatchSize > 0 {
|
|
||||||
batchSize = req.BatchSize
|
|
||||||
}
|
|
||||||
if req.Goroutines > 0 {
|
|
||||||
goroutines = req.Goroutines
|
|
||||||
}
|
|
||||||
}
|
|
||||||
n, err := service.AsyncWorker.RunOnce(ctx, batchSize, goroutines)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &dto.RunWorkRes{Claimed: n}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CleanWork 手动触发一次 cleaner(由上层定时任务调用)
|
// CleanWork 手动触发一次 cleaner(由上层定时任务调用)
|
||||||
func (c *task) CleanWork(ctx context.Context, req *dto.CleanWorkReq) (res *dto.CleanWorkRes, err error) {
|
func (c *task) CleanWork(ctx context.Context, req *dto.CleanWorkReq) (res *dto.CleanWorkRes, err error) {
|
||||||
service.Cleaner.RunOnce(ctx)
|
return service.Cleaner.RunOnce(ctx)
|
||||||
return &dto.CleanWorkRes{Ok: true}, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
190
dao/model_dao.go
190
dao/model_dao.go
@@ -2,14 +2,11 @@ package dao
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"model-gateway/consts/public"
|
"model-gateway/consts/public"
|
||||||
"model-gateway/model/dto"
|
"model-gateway/model/dto"
|
||||||
"model-gateway/model/entity"
|
"model-gateway/model/entity"
|
||||||
|
|
||||||
"gitea.com/red-future/common/db/gfdb"
|
"gitea.com/red-future/common/db/gfdb"
|
||||||
"gitea.com/red-future/common/utils"
|
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
"github.com/gogf/gf/v2/util/gconv"
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
)
|
)
|
||||||
@@ -18,157 +15,80 @@ var Model = &modelDao{}
|
|||||||
|
|
||||||
type modelDao struct{}
|
type modelDao struct{}
|
||||||
|
|
||||||
func (d *modelDao) Insert(ctx context.Context, req *dto.CreateModelReq) (id int64, err error) {
|
// Insert 插入
|
||||||
asyncModel := new(entity.AsynchModel)
|
func (d *modelDao) Insert(ctx context.Context, req *entity.AsynchModel) (id int64, err error) {
|
||||||
err = gconv.Struct(req, &asyncModel)
|
m := new(entity.AsynchModel)
|
||||||
|
err = gconv.Struct(req, &m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).Data(asyncModel).Insert()
|
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||||
|
Insert(m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return
|
||||||
}
|
}
|
||||||
return r.LastInsertId()
|
return r.LastInsertId()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *modelDao) Update(ctx context.Context, m *dto.UpdateModelReq) (rows int64, err error) {
|
// Update 更新
|
||||||
// 触发 gfdb 的 updateHook 自动填充 updater,需要显式带 updater 字段
|
func (d *modelDao) Update(ctx context.Context, req *entity.AsynchModel) (rows int64, err error) {
|
||||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||||
OmitEmpty().
|
OmitEmpty().
|
||||||
Where(entity.AsynchModelCol.Id, m.ID).
|
Data(&req).
|
||||||
Data(m).
|
Where(entity.AsynchModelCol.Id, req.Id).
|
||||||
Update()
|
Update()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return
|
||||||
}
|
}
|
||||||
return r.RowsAffected()
|
return r.RowsAffected()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *modelDao) DeleteByID(ctx context.Context, id string) (rows int64, err error) {
|
// Delete 删除
|
||||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
func (d *modelDao) Delete(ctx context.Context, req *entity.AsynchModel) (rows int64, err error) {
|
||||||
Where(entity.AsynchModelCol.Id, id).
|
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||||
|
OmitEmpty().
|
||||||
|
Where(entity.AsynchModelCol.Id, req.Id).
|
||||||
Delete()
|
Delete()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return
|
||||||
}
|
}
|
||||||
return r.RowsAffected()
|
return r.RowsAffected()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *modelDao) GetByModelName(ctx context.Context, modelName string) (m *entity.AsynchModel, err error) {
|
// Get 按ID获取(带租户隔离,只查当前租户)
|
||||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
|
||||||
Where(entity.AsynchModelCol.ModelName, modelName).
|
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||||
One()
|
OmitEmpty().
|
||||||
if err != nil {
|
Where(entity.AsynchModelCol.Id, req.Id).
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if r.IsEmpty() {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
err = r.Struct(&m)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *modelDao) Get(ctx context.Context, id int64) (m *entity.AsynchModel, err error) {
|
|
||||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
|
||||||
NoTenantId(ctx).
|
|
||||||
Where(entity.AsynchModelCol.Id, id).
|
|
||||||
One()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if r.IsEmpty() {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
err = r.Struct(&m)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *modelDao) Count(ctx context.Context, req *dto.GetModelReq) (count int, err error) {
|
|
||||||
count, err = gfdb.DB(ctx).Model(ctx, public.TableNameModel).OmitEmpty().
|
|
||||||
Where(entity.AsynchModelCol.Creator, req.Creator).
|
Where(entity.AsynchModelCol.Creator, req.Creator).
|
||||||
Where(entity.AsynchModelCol.Id, req.ID).Count()
|
Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel).
|
||||||
return
|
Where(entity.AsynchModelCol.ModelName, req.ModelName).
|
||||||
}
|
Fields(fields).One()
|
||||||
|
|
||||||
func (d *modelDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike string, modelType int, isPrivate int) (list []*entity.AsynchModel, total int64, err error) {
|
|
||||||
model := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
|
||||||
OrderDesc(entity.AsynchModelCol.CreatedAt)
|
|
||||||
if modelNameLike != "" {
|
|
||||||
model = model.WhereLike(entity.AsynchModelCol.ModelName, "%"+modelNameLike+"%")
|
|
||||||
}
|
|
||||||
if modelType != 0 {
|
|
||||||
model = model.Where(entity.AsynchModelCol.ModelType, modelType)
|
|
||||||
}
|
|
||||||
if isPrivate != 0 {
|
|
||||||
model = model.Where(entity.AsynchModelCol.IsPrivate, isPrivate)
|
|
||||||
}
|
|
||||||
if pageNum > 0 && pageSize > 0 {
|
|
||||||
model = model.Page(pageNum, pageSize)
|
|
||||||
}
|
|
||||||
r, totalInt, err := model.AllAndCount(false)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return
|
||||||
}
|
|
||||||
total = gconv.Int64(totalInt)
|
|
||||||
err = r.Structs(&list)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *modelDao) GetByIsChatModel(ctx context.Context) (m *entity.AsynchModel, err error) {
|
|
||||||
userInfo, err := utils.GetUserInfo(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
|
||||||
Where(entity.AsynchModelCol.IsChatModel, 1).
|
|
||||||
Where(entity.AsynchModelCol.Creator, userInfo.UserName).
|
|
||||||
One()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if r.IsEmpty() {
|
|
||||||
return nil, nil
|
|
||||||
}
|
}
|
||||||
err = r.Struct(&m)
|
err = r.Struct(&m)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListByCreatorAndPlatform 普通用户:平台公共(tenant_id=0) + 自己创建的(creator=xxx)
|
// GetByAcrossTenant 按ID获取(跨租户,查所有租户)
|
||||||
func (d *modelDao) ListByCreatorAndPlatform(ctx context.Context, creator string, pageNum, pageSize int, modelNameLike string) (list []*entity.AsynchModel, total int64, err error) {
|
func (d *modelDao) GetByAcrossTenant(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
|
||||||
// 构建 Where 条件
|
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||||
whereSQL := "deleted_at IS NULL AND (tenant_id = 1 OR creator = ?)" //1 代表超级管理员
|
NoTenantId(ctx).
|
||||||
args := []any{creator}
|
OmitEmpty().
|
||||||
|
Where(entity.AsynchModelCol.Id, req.Id).
|
||||||
if modelNameLike != "" {
|
Where(entity.AsynchModelCol.Creator, req.Creator).
|
||||||
whereSQL += " AND model_name LIKE ?"
|
Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel).
|
||||||
args = append(args, "%"+modelNameLike+"%")
|
Where(entity.AsynchModelCol.ModelName, req.ModelName).
|
||||||
}
|
Fields(fields).One()
|
||||||
|
|
||||||
// 查总数
|
|
||||||
countSQL := fmt.Sprintf("SELECT COUNT(1) FROM %s WHERE %s", public.TableNameModel, whereSQL)
|
|
||||||
countResult, err := gfdb.DB(ctx).GetAll(ctx, countSQL, args...)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return
|
||||||
}
|
}
|
||||||
if len(countResult) > 0 {
|
err = r.Struct(&m)
|
||||||
total = gconv.Int64(countResult[0]["count"])
|
|
||||||
}
|
|
||||||
|
|
||||||
// 查列表
|
|
||||||
querySQL := fmt.Sprintf("SELECT * FROM %s WHERE %s ORDER BY created_at DESC", public.TableNameModel, whereSQL)
|
|
||||||
if pageNum > 0 && pageSize > 0 {
|
|
||||||
offset := (pageNum - 1) * pageSize
|
|
||||||
querySQL += fmt.Sprintf(" LIMIT %d OFFSET %d", pageSize, offset)
|
|
||||||
}
|
|
||||||
|
|
||||||
r, err := gfdb.DB(ctx).GetAll(ctx, querySQL, args...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = r.Structs(&list)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetByCreatorAndPlatform 按创建者、平台获取
|
||||||
func (d *modelDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.AsynchModel, total int, err error) {
|
func (d *modelDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.AsynchModel, total int, err error) {
|
||||||
// 基础 SQL
|
// 基础 SQL
|
||||||
sql := `
|
sql := `
|
||||||
@@ -212,7 +132,7 @@ WHERE deleted_at IS NULL
|
|||||||
// 最后拼接排序
|
// 最后拼接排序
|
||||||
sql += ` ORDER BY model_name, is_owner DESC, created_at DESC`
|
sql += ` ORDER BY model_name, is_owner DESC, created_at DESC`
|
||||||
|
|
||||||
r, err := gfdb.DB(ctx).GetAll(ctx, sql, args...)
|
r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, sql, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
@@ -226,14 +146,24 @@ WHERE deleted_at IS NULL
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListAll 用于分组展示:查询全部模型(不按类型过滤,类型拆分在 service 层处理)
|
// GetByModelNameForTenant 后台任务使用:按 tenant_id + model_name 查询,不依赖 gfdb Hook/Trace/用户上下文
|
||||||
func (d *modelDao) ListAll(ctx context.Context) (list []*entity.AsynchModel, err error) {
|
func (d *modelDao) GetByModelNameForTenant(ctx context.Context, tenantId uint64, modelName string) (m *entity.AsynchModel, err error) {
|
||||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameModel).
|
r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx,
|
||||||
OrderDesc(entity.AsynchModelCol.CreatedAt).
|
"SELECT * FROM "+public.TableNameModel+" WHERE tenant_id=? AND model_name=? AND deleted_at IS NULL LIMIT 1",
|
||||||
All()
|
tenantId, modelName,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
err = r.Structs(&list)
|
if r.IsEmpty() {
|
||||||
return
|
return nil, nil
|
||||||
|
}
|
||||||
|
var list []*entity.AsynchModel
|
||||||
|
if err := r.Structs(&list); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(list) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return list[0], nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,32 +0,0 @@
|
|||||||
package dao
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"model-gateway/consts/public"
|
|
||||||
"model-gateway/model/entity"
|
|
||||||
|
|
||||||
"gitea.com/red-future/common/db/gfdb"
|
|
||||||
)
|
|
||||||
|
|
||||||
// GetByModelNameForTenant 后台任务使用:按 tenant_id + model_name 查询,不依赖 gfdb Hook/Trace/用户上下文
|
|
||||||
func (d *modelDao) GetByModelNameForTenant(ctx context.Context, tenantId uint64, modelName string) (m *entity.AsynchModel, err error) {
|
|
||||||
r, err := gfdb.DB(ctx).GetAll(ctx,
|
|
||||||
"SELECT * FROM "+public.TableNameModel+" WHERE tenant_id=? AND model_name=? AND deleted_at IS NULL LIMIT 1",
|
|
||||||
tenantId, modelName,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if r.IsEmpty() {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
var list []*entity.AsynchModel
|
|
||||||
if err := r.Structs(&list); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if len(list) == 0 {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
return list[0], nil
|
|
||||||
}
|
|
||||||
@@ -7,14 +7,22 @@ import (
|
|||||||
"model-gateway/model/entity"
|
"model-gateway/model/entity"
|
||||||
|
|
||||||
"gitea.com/red-future/common/db/gfdb"
|
"gitea.com/red-future/common/db/gfdb"
|
||||||
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
type opLogDao struct{}
|
type opLogDao struct{}
|
||||||
|
|
||||||
var OpLog = &opLogDao{}
|
var OpLog = &opLogDao{}
|
||||||
|
|
||||||
func (d *opLogDao) Insert(ctx context.Context, log *entity.LogsModelOp) (id int64, err error) {
|
// Insert 插入
|
||||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameOpLog).Data(log).Insert()
|
func (d *opLogDao) Insert(ctx context.Context, req *entity.LogsModelOp) (id int64, err error) {
|
||||||
|
m := new(entity.LogsModelOp)
|
||||||
|
err = gconv.Struct(req, &m)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameOpLog).
|
||||||
|
Insert(m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ ON CONFLICT (day, tenant_id, creator, model_name)
|
|||||||
DO UPDATE SET request_count = %s.request_count + 1, updated_at = NOW()`,
|
DO UPDATE SET request_count = %s.request_count + 1, updated_at = NOW()`,
|
||||||
public.TableNameStat, public.TableNameStat,
|
public.TableNameStat, public.TableNameStat,
|
||||||
)
|
)
|
||||||
_, err := gfdb.DB(ctx).Exec(ctx, sql, gtime.New(day).Format("Y-m-d"), tenantId, creator, modelName)
|
_, err := gfdb.DB(ctx, public.DbNameModelGateway).Exec(ctx, sql, gtime.New(day).Format("Y-m-d"), tenantId, creator, modelName)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
196
dao/task_dao.go
196
dao/task_dao.go
@@ -2,9 +2,6 @@ package dao
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"model-gateway/consts/public"
|
"model-gateway/consts/public"
|
||||||
"model-gateway/model/entity"
|
"model-gateway/model/entity"
|
||||||
|
|
||||||
@@ -18,40 +15,47 @@ var Task = &taskDao{}
|
|||||||
|
|
||||||
type taskDao struct{}
|
type taskDao struct{}
|
||||||
|
|
||||||
func (d *taskDao) Insert(ctx context.Context, t *entity.AsynchTask) (id int64, err error) {
|
// Insert 插入
|
||||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).Data(t).Insert()
|
func (d *taskDao) Insert(ctx context.Context, req *entity.AsynchTask) (id int64, err error) {
|
||||||
|
m := new(entity.AsynchTask)
|
||||||
|
err = gconv.Struct(req, &m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return
|
||||||
|
}
|
||||||
|
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||||||
|
Insert(m)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
return r.LastInsertId()
|
return r.LastInsertId()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *taskDao) GetByTaskID(ctx context.Context, taskID string) (t *entity.AsynchTask, err error) {
|
// Get 获取
|
||||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
func (d *taskDao) Get(ctx context.Context, req *entity.AsynchTask, fields ...string) (m *entity.AsynchTask, err error) {
|
||||||
Where(entity.AsynchTaskCol.TaskID, taskID).
|
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
|
||||||
One()
|
OmitEmpty().
|
||||||
|
Where(entity.AsynchTaskCol.TaskID, req.TaskID).
|
||||||
|
Fields(fields).One()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return
|
||||||
}
|
}
|
||||||
if r.IsEmpty() {
|
err = r.Struct(&m)
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
err = r.Struct(&t)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListByTaskIDs 批量查询任务(会受 gfdb 的租户 Hook 影响,只返回当前租户数据)
|
// ListByTaskIDs 批量查询任务(会受 gfdb 的租户 Hook 影响,只返回当前租户数据)
|
||||||
func (d *taskDao) ListByTaskIDs(ctx context.Context, taskIDs []string) (list []*entity.AsynchTask, err error) {
|
func (d *taskDao) ListByTaskIDs(ctx context.Context, taskIDs []string) (m []*entity.AsynchTask, err error) {
|
||||||
if len(taskIDs) == 0 {
|
if len(taskIDs) == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
|
||||||
|
OmitEmpty().
|
||||||
WhereIn(entity.AsynchTaskCol.TaskID, taskIDs).
|
WhereIn(entity.AsynchTaskCol.TaskID, taskIDs).
|
||||||
All()
|
All()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
err = r.Structs(&list)
|
err = r.Structs(&m)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,7 +66,7 @@ func (d *taskDao) MarkDownloadedByID(ctx context.Context, id int64, expireAt *gt
|
|||||||
entity.AsynchTaskCol.ExpireAt: expireAt,
|
entity.AsynchTaskCol.ExpireAt: expireAt,
|
||||||
entity.AsynchTaskCol.Updater: "",
|
entity.AsynchTaskCol.Updater: "",
|
||||||
}
|
}
|
||||||
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
_, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
|
||||||
Where(entity.AsynchTaskCol.Id, id).
|
Where(entity.AsynchTaskCol.Id, id).
|
||||||
Where(entity.AsynchTaskCol.State, 2).
|
Where(entity.AsynchTaskCol.State, 2).
|
||||||
Data(data).
|
Data(data).
|
||||||
@@ -70,73 +74,6 @@ func (d *taskDao) MarkDownloadedByID(ctx context.Context, id int64, expireAt *gt
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *taskDao) UpdateRunning(ctx context.Context, id int64) error {
|
|
||||||
now := gtime.Now()
|
|
||||||
data := gdb.Map{
|
|
||||||
entity.AsynchTaskCol.State: 1,
|
|
||||||
entity.AsynchTaskCol.StartedAt: now,
|
|
||||||
entity.AsynchTaskCol.Updater: "",
|
|
||||||
}
|
|
||||||
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
|
||||||
Where(entity.AsynchTaskCol.Id, id).
|
|
||||||
Data(data).
|
|
||||||
Update()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *taskDao) UpdateSuccess(ctx context.Context, id int64, ossFile, fileType string, fileSize int64, expireAt *gtime.Time) error {
|
|
||||||
now := gtime.Now()
|
|
||||||
data := gdb.Map{
|
|
||||||
entity.AsynchTaskCol.State: 2,
|
|
||||||
entity.AsynchTaskCol.OssFile: ossFile,
|
|
||||||
entity.AsynchTaskCol.FileType: fileType,
|
|
||||||
entity.AsynchTaskCol.FileSize: fileSize,
|
|
||||||
entity.AsynchTaskCol.ErrorMsg: "",
|
|
||||||
entity.AsynchTaskCol.FinishedAt: now,
|
|
||||||
entity.AsynchTaskCol.ExpireAt: expireAt,
|
|
||||||
entity.AsynchTaskCol.Updater: "",
|
|
||||||
}
|
|
||||||
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
|
||||||
Where(entity.AsynchTaskCol.Id, id).
|
|
||||||
Data(data).
|
|
||||||
Update()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *taskDao) UpdateFailed(ctx context.Context, id int64, errorMsg string) error {
|
|
||||||
now := gtime.Now()
|
|
||||||
data := gdb.Map{
|
|
||||||
entity.AsynchTaskCol.State: 3,
|
|
||||||
entity.AsynchTaskCol.ErrorMsg: errorMsg,
|
|
||||||
entity.AsynchTaskCol.FinishedAt: now,
|
|
||||||
entity.AsynchTaskCol.Updater: "",
|
|
||||||
}
|
|
||||||
_, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
|
||||||
Where(entity.AsynchTaskCol.Id, id).
|
|
||||||
Data(data).
|
|
||||||
Update()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *taskDao) SoftDeleteByTaskID(ctx context.Context, taskID string) (rows int64, err error) {
|
|
||||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
|
||||||
Where(entity.AsynchTaskCol.TaskID, taskID).
|
|
||||||
Delete()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return r.RowsAffected()
|
|
||||||
}
|
|
||||||
|
|
||||||
// CountActiveByModel 统计某模型排队中/执行中的任务数,用于 queue_limit 限制(近似值)
|
|
||||||
func (d *taskDao) CountActiveByModel(ctx context.Context, modelName string) (int64, error) {
|
|
||||||
n, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
|
||||||
Where(entity.AsynchTaskCol.ModelName, modelName).
|
|
||||||
WhereIn(entity.AsynchTaskCol.State, []int{0, 1}).
|
|
||||||
Count()
|
|
||||||
return int64(n), err
|
|
||||||
}
|
|
||||||
|
|
||||||
// List 任务分页查询(受 gfdb 租户 Hook 影响)
|
// List 任务分页查询(受 gfdb 租户 Hook 影响)
|
||||||
func (d *taskDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike, taskIDLike string, state *int) (list []*entity.AsynchTask, total int64, err error) {
|
func (d *taskDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike, taskIDLike string, state *int) (list []*entity.AsynchTask, total int64, err error) {
|
||||||
m := gfdb.DB(ctx).Model(ctx, public.TableNameTask).Where("deleted_at IS NULL")
|
m := gfdb.DB(ctx).Model(ctx, public.TableNameTask).Where("deleted_at IS NULL")
|
||||||
@@ -161,90 +98,3 @@ func (d *taskDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike
|
|||||||
err = r.Structs(&list)
|
err = r.Structs(&list)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClaimPending 抢占 pending 任务(state=0),并在同一事务中更新为 running(state=1)
|
|
||||||
// 使用 PostgreSQL: FOR UPDATE SKIP LOCKED 避免多 worker 重复消费
|
|
||||||
func (d *taskDao) ClaimPending(ctx context.Context, batchSize int) (tasks []*entity.AsynchTask, err error) {
|
|
||||||
if batchSize <= 0 {
|
|
||||||
batchSize = 1
|
|
||||||
}
|
|
||||||
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
|
||||||
sql := fmt.Sprintf(
|
|
||||||
`SELECT id, tenant_id, model_name, task_id, input_ref, request_payload
|
|
||||||
FROM %s
|
|
||||||
WHERE deleted_at IS NULL AND state = 0
|
|
||||||
ORDER BY created_at ASC
|
|
||||||
LIMIT %d
|
|
||||||
FOR UPDATE SKIP LOCKED`,
|
|
||||||
public.TableNameTask,
|
|
||||||
batchSize,
|
|
||||||
)
|
|
||||||
r, err := tx.GetAll(sql)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if r.IsEmpty() {
|
|
||||||
tasks = nil
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err := r.Structs(&tasks); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// 更新为 running
|
|
||||||
now := time.Now()
|
|
||||||
for _, t := range tasks {
|
|
||||||
// tx.Model 不走 gfdb Hook,这里手动更新必要字段
|
|
||||||
_, err = tx.Exec(
|
|
||||||
fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask),
|
|
||||||
now, now, t.Id,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListExpiredSuccess 获取已成功且过期的任务
|
|
||||||
func (d *taskDao) ListExpiredSuccess(ctx context.Context, limit int) (list []*entity.AsynchTask, err error) {
|
|
||||||
if limit <= 0 {
|
|
||||||
limit = 100
|
|
||||||
}
|
|
||||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
|
||||||
Where(entity.AsynchTaskCol.State, 2).
|
|
||||||
Where(entity.AsynchTaskCol.ExpireAt+" IS NOT NULL").
|
|
||||||
Where(entity.AsynchTaskCol.ExpireAt+" < ?", gtime.Now()).
|
|
||||||
Limit(limit).
|
|
||||||
All()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
err = r.Structs(&list)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListTimeoutTasks 获取超时的排队/执行中任务
|
|
||||||
func (d *taskDao) ListTimeoutTasks(ctx context.Context, timeout time.Duration, limit int) (list []*entity.AsynchTask, err error) {
|
|
||||||
if limit <= 0 {
|
|
||||||
limit = 100
|
|
||||||
}
|
|
||||||
deadline := gtime.New(time.Now().Add(-timeout))
|
|
||||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
|
||||||
WhereIn(entity.AsynchTaskCol.State, []int{0, 1}).
|
|
||||||
Where(entity.AsynchTaskCol.UpdatedAt+" < ?", deadline).
|
|
||||||
Limit(limit).
|
|
||||||
All()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
err = r.Structs(&list)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// DebugPing 用于启动时检测数据库连通性(可选)
|
|
||||||
func (d *taskDao) DebugPing(ctx context.Context) error {
|
|
||||||
_, err := gfdb.DB(ctx).GetAll(ctx, "SELECT 1")
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -150,14 +150,6 @@ func (d *taskDao) UpdateTmpAfterModelGlobal(ctx context.Context, id int64, tmpFi
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *taskDao) SoftDeleteByTaskIDGlobal(ctx context.Context, taskID string) error {
|
|
||||||
_, err := gfdb.DB(ctx).Exec(ctx,
|
|
||||||
fmt.Sprintf(`UPDATE %s SET deleted_at=NOW(), updated_at=NOW() WHERE task_id=? AND deleted_at IS NULL`, public.TableNameTask),
|
|
||||||
taskID,
|
|
||||||
)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *taskDao) RollbackToPendingGlobal(ctx context.Context, id int64) error {
|
func (d *taskDao) RollbackToPendingGlobal(ctx context.Context, id int64) error {
|
||||||
_, err := gfdb.DB(ctx).Exec(ctx,
|
_, err := gfdb.DB(ctx).Exec(ctx,
|
||||||
fmt.Sprintf(`UPDATE %s SET state=0, enqueue_at=NOW(), updated_at=NOW() WHERE id=? AND state=1`, public.TableNameTask),
|
fmt.Sprintf(`UPDATE %s SET state=0, enqueue_at=NOW(), updated_at=NOW() WHERE id=? AND state=1`, public.TableNameTask),
|
||||||
|
|||||||
8
main.go
8
main.go
@@ -2,6 +2,7 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"model-gateway/model/dto"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"syscall"
|
"syscall"
|
||||||
@@ -61,7 +62,10 @@ func startAutoRunner(ctx context.Context) {
|
|||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
if _, err := service.AsyncWorker.RunOnce(ctx, batchSize, goroutines); err != nil {
|
if _, err := service.AsyncWorker.RunOnce(ctx, &dto.RunWorkReq{
|
||||||
|
BatchSize: batchSize,
|
||||||
|
Goroutines: goroutines,
|
||||||
|
}); err != nil {
|
||||||
g.Log().Warningf(ctx, "[auto-worker] run once failed: %v", err)
|
g.Log().Warningf(ctx, "[auto-worker] run once failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -83,7 +87,7 @@ func startAutoRunner(ctx context.Context) {
|
|||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
service.Cleaner.RunOnce(ctx)
|
_, _ = service.Cleaner.RunOnce(ctx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|||||||
@@ -17,12 +17,14 @@ type CreateModelReq struct {
|
|||||||
Enabled *int `p:"enabled" json:"enabled" v:"in:0,1#启用参数只能为0或1" dc:"是否启用:0-禁用,1-启用(默认1)"`
|
Enabled *int `p:"enabled" json:"enabled" v:"in:0,1#启用参数只能为0或1" dc:"是否启用:0-禁用,1-启用(默认1)"`
|
||||||
IsChatModel *int `p:"isChatModel" json:"isChatModel" v:"in:0,1#对话模型参数只能为0或1" dc:"是否为对话模型:0-否,1-是(默认0)"`
|
IsChatModel *int `p:"isChatModel" json:"isChatModel" v:"in:0,1#对话模型参数只能为0或1" dc:"是否为对话模型:0-否,1-是(默认0)"`
|
||||||
IsOwner *int `p:"isOwner" json:"isOwner" v:"in:0,1#是否为所有者参数只能为0或1" dc:"是否为所有者:0-否,1-是(默认0)"`
|
IsOwner *int `p:"isOwner" json:"isOwner" v:"in:0,1#是否为所有者参数只能为0或1" dc:"是否为所有者:0-否,1-是(默认0)"`
|
||||||
|
OperatorName string `p:"operatorName" json:"operatorName" v:"required#operatorName不能为空" dc:"运营商名称"`
|
||||||
|
TokenConfig any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"`
|
||||||
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥,用于模型认证"`
|
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥,用于模型认证"`
|
||||||
Form any `p:"form" json:"form" dc:"动态表单配置(JSON),用于前端渲染配置项"`
|
Form any `p:"form" json:"form" dc:"动态表单配置(JSON),用于前端渲染配置项"`
|
||||||
RequestMapping any `p:"requestMapping" json:"requestMapping" dc:"请求映射"`
|
RequestMapping any `p:"requestMapping" json:"requestMapping" dc:"请求映射"`
|
||||||
ResponseMapping any `p:"responseMapping" json:"responseMapping" dc:"返回映射"`
|
ResponseMapping any `p:"responseMapping" json:"responseMapping" dc:"返回映射"`
|
||||||
ResponseBody any `p:"responseBody" json:"responseBody" dc:"返回主体"`
|
ResponseBody any `p:"responseBody" json:"responseBody" dc:"返回主体"`
|
||||||
TokenMapping string `p:"tokenMapping" json:"tokenMapping" dc:"token映射"`
|
ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"`
|
||||||
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数(默认10)"`
|
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数(默认10)"`
|
||||||
QueueLimit int `p:"queueLimit" json:"queueLimit" dc:"排队队列上限(默认1000)"`
|
QueueLimit int `p:"queueLimit" json:"queueLimit" dc:"排队队列上限(默认1000)"`
|
||||||
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒,默认600)"`
|
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒,默认600)"`
|
||||||
@@ -50,11 +52,13 @@ type UpdateModelReq struct {
|
|||||||
RequestMapping any `p:"requestMapping" json:"requestMapping" dc:"请求参数映射(可选更新)"`
|
RequestMapping any `p:"requestMapping" json:"requestMapping" dc:"请求参数映射(可选更新)"`
|
||||||
ResponseMapping any `p:"responseMapping" json:"responseMapping" dc:"返回参数映射(可选更新)"`
|
ResponseMapping any `p:"responseMapping" json:"responseMapping" dc:"返回参数映射(可选更新)"`
|
||||||
ResponseBody any `p:"responseBody" json:"responseBody" dc:"返回主体(可选更新)"`
|
ResponseBody any `p:"responseBody" json:"responseBody" dc:"返回主体(可选更新)"`
|
||||||
TokenMapping string `p:"tokenMapping" json:"tokenMapping" dc:"token映射(可选更新)"`
|
ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"`
|
||||||
Enabled *int `p:"enabled" json:"enabled" dc:"是否启用:0-禁用,1-启用(可选更新)"`
|
Enabled *int `p:"enabled" json:"enabled" dc:"是否启用:0-禁用,1-启用(可选更新)"`
|
||||||
IsPrivate *int `p:"isPrivate" json:"isPrivate" v:"in:0,1#私有化参数只能为0或1" dc:"是否私有化:0-私有(默认) 1-公共"`
|
IsPrivate *int `p:"isPrivate" json:"isPrivate" v:"in:0,1#私有化参数只能为0或1" dc:"是否私有化:0-私有(默认) 1-公共"`
|
||||||
IsChatModel *int `p:"isChatModel" json:"isChatModel" v:"in:0,1#对话模型参数只能为0或1" dc:"是否为对话模型:0-否,1-是(默认0)"`
|
IsChatModel *int `p:"isChatModel" json:"isChatModel" v:"in:0,1#对话模型参数只能为0或1" dc:"是否为对话模型:0-否,1-是(默认0)"`
|
||||||
IsOwner *int `p:"isOwner" json:"isOwner" v:"in:0,1#是否为所有者参数只能为0或1" dc:"是否为所有者:0-否,1-是(默认0)"`
|
IsOwner *int `p:"isOwner" json:"isOwner" v:"in:0,1#是否为所有者参数只能为0或1" dc:"是否为所有者:0-否,1-是(默认0)"`
|
||||||
|
OperatorName string `p:"operatorName" json:"operatorName" v:"required#operatorName不能为空" dc:"运营商名称"`
|
||||||
|
TokenConfig any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"`
|
||||||
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数(可选更新)"`
|
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数(可选更新)"`
|
||||||
QueueLimit int `p:"queueLimit" json:"queueLimit" dc:"排队队列上限(可选更新)"`
|
QueueLimit int `p:"queueLimit" json:"queueLimit" dc:"排队队列上限(可选更新)"`
|
||||||
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒)(可选更新)"`
|
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒)(可选更新)"`
|
||||||
@@ -65,10 +69,18 @@ type UpdateModelReq struct {
|
|||||||
Remark string `p:"remark" json:"remark" dc:"备注说明(可选更新)"`
|
Remark string `p:"remark" json:"remark" dc:"备注说明(可选更新)"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type UpdateModelRes struct {
|
||||||
|
ID int64 `json:"id,string" dc:"配置ID"`
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteModelReq 删除模型配置
|
// DeleteModelReq 删除模型配置
|
||||||
type DeleteModelReq struct {
|
type DeleteModelReq struct {
|
||||||
g.Meta `path:"/deleteModel" method:"delete" tags:"模型管理" summary:"删除模型配置" dc:"删除指定ID的模型配置"`
|
g.Meta `path:"/deleteModel" method:"delete" tags:"模型管理" summary:"删除模型配置" dc:"删除指定ID的模型配置"`
|
||||||
ID string `p:"id" json:"id,string" v:"required#id不能为空" dc:"配置ID"`
|
ID int64 `p:"id" json:"id,string" v:"required#id不能为空" dc:"配置ID"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type DeleteModelRes struct {
|
||||||
|
ID int64 `json:"id,string" dc:"配置ID"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetModelReq 获取模型配置详情
|
// GetModelReq 获取模型配置详情
|
||||||
@@ -128,7 +140,14 @@ type UpdateChatModelReq struct {
|
|||||||
g.Meta `path:"/updateChatModel" method:"post" tags:"模型管理" summary:"更新聊天模型" dc:"更新指定模型的聊天模型"`
|
g.Meta `path:"/updateChatModel" method:"post" tags:"模型管理" summary:"更新聊天模型" dc:"更新指定模型的聊天模型"`
|
||||||
Id int64 `p:"id" json:"id" v:"required#model不能为空" dc:"模型id"`
|
Id int64 `p:"id" json:"id" v:"required#model不能为空" dc:"模型id"`
|
||||||
}
|
}
|
||||||
|
type UpdateChatModelRes struct {
|
||||||
|
ID int64 `json:"id,string" dc:"模型ID"`
|
||||||
|
}
|
||||||
|
|
||||||
type GetIsChatModelReq struct {
|
type GetIsChatModelReq struct {
|
||||||
g.Meta `path:"/getIsChatModel" method:"get" tags:"模型管理" summary:"获取模型是否为聊天模型" dc:"根据模型ID获取是否为聊天模型"`
|
g.Meta `path:"/getIsChatModel" method:"get" tags:"模型管理" summary:"获取模型是否为聊天模型" dc:"根据模型ID获取是否为聊天模型"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type GetIsChatModelRes struct {
|
||||||
|
Model any `json:"model" dc:"模型详情"`
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,58 +4,62 @@ import "gitea.com/red-future/common/beans"
|
|||||||
|
|
||||||
type asynchModelCol struct {
|
type asynchModelCol struct {
|
||||||
beans.SQLBaseCol
|
beans.SQLBaseCol
|
||||||
ModelName string
|
ModelName string
|
||||||
ModelType string
|
ModelType string
|
||||||
BaseURL string
|
BaseURL string
|
||||||
HttpMethod string
|
HttpMethod string
|
||||||
HeadMsg string
|
HeadMsg string
|
||||||
FormJSON string
|
FormJSON string
|
||||||
RequestMapping string
|
RequestMapping string
|
||||||
ResponseMapping string
|
ResponseMapping string
|
||||||
ResponseBody string
|
ResponseBody string
|
||||||
TokenMapping string
|
ResponseTokenField string
|
||||||
Prompt string
|
Prompt string
|
||||||
IsPrivate string
|
IsPrivate string
|
||||||
IsChatModel string
|
IsChatModel string
|
||||||
ApiKey string
|
ApiKey string
|
||||||
Enabled string
|
Enabled string
|
||||||
MaxConcurrency string
|
MaxConcurrency string
|
||||||
QueueLimit string
|
QueueLimit string
|
||||||
TimeoutSeconds string
|
TimeoutSeconds string
|
||||||
ExpectedSeconds string
|
ExpectedSeconds string
|
||||||
RetryTimes string
|
RetryTimes string
|
||||||
RetryQueueMaxSecs string
|
RetryQueueMaxSecs string
|
||||||
AutoCleanSeconds string
|
AutoCleanSeconds string
|
||||||
Remark string
|
Remark string
|
||||||
IsOwner string
|
IsOwner string
|
||||||
|
OperatorName string
|
||||||
|
TokenConfig string
|
||||||
}
|
}
|
||||||
|
|
||||||
var AsynchModelCol = asynchModelCol{
|
var AsynchModelCol = asynchModelCol{
|
||||||
SQLBaseCol: beans.DefSQLBaseCol,
|
SQLBaseCol: beans.DefSQLBaseCol,
|
||||||
ModelName: "model_name",
|
ModelName: "model_name",
|
||||||
ModelType: "model_type",
|
ModelType: "model_type",
|
||||||
BaseURL: "base_url",
|
BaseURL: "base_url",
|
||||||
HttpMethod: "http_method",
|
HttpMethod: "http_method",
|
||||||
HeadMsg: "head_msg",
|
HeadMsg: "head_msg",
|
||||||
FormJSON: "form_json",
|
FormJSON: "form_json",
|
||||||
RequestMapping: "request_mapping",
|
RequestMapping: "request_mapping",
|
||||||
ResponseMapping: "response_mapping",
|
ResponseMapping: "response_mapping",
|
||||||
ResponseBody: "response_body",
|
ResponseBody: "response_body",
|
||||||
TokenMapping: "token_mapping",
|
ResponseTokenField: "response_token_field",
|
||||||
Prompt: "prompt",
|
Prompt: "prompt",
|
||||||
IsPrivate: "is_private",
|
IsPrivate: "is_private",
|
||||||
IsChatModel: "is_chat_model",
|
IsChatModel: "is_chat_model",
|
||||||
ApiKey: "api_key",
|
ApiKey: "api_key",
|
||||||
Enabled: "enabled",
|
Enabled: "enabled",
|
||||||
MaxConcurrency: "max_concurrency",
|
MaxConcurrency: "max_concurrency",
|
||||||
QueueLimit: "queue_limit",
|
QueueLimit: "queue_limit",
|
||||||
TimeoutSeconds: "timeout_seconds",
|
TimeoutSeconds: "timeout_seconds",
|
||||||
ExpectedSeconds: "expected_seconds",
|
ExpectedSeconds: "expected_seconds",
|
||||||
RetryTimes: "retry_times",
|
RetryTimes: "retry_times",
|
||||||
RetryQueueMaxSecs: "retry_queue_max_seconds",
|
RetryQueueMaxSecs: "retry_queue_max_seconds",
|
||||||
AutoCleanSeconds: "auto_clean_seconds",
|
AutoCleanSeconds: "auto_clean_seconds",
|
||||||
Remark: "remark",
|
Remark: "remark",
|
||||||
IsOwner: "is_owner",
|
IsOwner: "is_owner",
|
||||||
|
OperatorName: "operator_name",
|
||||||
|
TokenConfig: "token_config",
|
||||||
}
|
}
|
||||||
|
|
||||||
// AsynchModel 异步模型配置
|
// AsynchModel 异步模型配置
|
||||||
@@ -70,7 +74,7 @@ type AsynchModel struct {
|
|||||||
RequestMapping any `orm:"request_mapping" json:"requestMapping"`
|
RequestMapping any `orm:"request_mapping" json:"requestMapping"`
|
||||||
ResponseMapping any `orm:"response_mapping" json:"responseMapping"`
|
ResponseMapping any `orm:"response_mapping" json:"responseMapping"`
|
||||||
ResponseBody any `orm:"response_body" json:"responseBody"`
|
ResponseBody any `orm:"response_body" json:"responseBody"`
|
||||||
TokenMapping string `orm:"token_mapping" json:"tokenMapping"`
|
ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"`
|
||||||
Prompt string `orm:"prompt" json:"prompt"`
|
Prompt string `orm:"prompt" json:"prompt"`
|
||||||
IsPrivate *int `orm:"is_private" json:"isPrivate"`
|
IsPrivate *int `orm:"is_private" json:"isPrivate"`
|
||||||
IsChatModel *int `orm:"is_chat_model" json:"isChatModel"`
|
IsChatModel *int `orm:"is_chat_model" json:"isChatModel"`
|
||||||
@@ -84,5 +88,7 @@ type AsynchModel struct {
|
|||||||
RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"retryQueueMaxSeconds"`
|
RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"retryQueueMaxSeconds"`
|
||||||
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
|
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
|
||||||
Remark string `orm:"remark" json:"remark"`
|
Remark string `orm:"remark" json:"remark"`
|
||||||
IsOwner *int `json:"isOwner" orm:"is_owner"` // 1=当前用户创建的,0=超级管理员的
|
IsOwner *int `json:"isOwner" orm:"is_owner"`
|
||||||
|
OperatorName string `orm:"operator_name" json:"operatorName"`
|
||||||
|
TokenConfig any `orm:"token_config" json:"tokenConfig"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,26 +0,0 @@
|
|||||||
package entity
|
|
||||||
|
|
||||||
import "gitea.com/red-future/common/beans"
|
|
||||||
|
|
||||||
type asynchModelTypeCol struct {
|
|
||||||
beans.SQLBaseCol
|
|
||||||
TypeID string
|
|
||||||
TypeName string
|
|
||||||
Remark string
|
|
||||||
}
|
|
||||||
|
|
||||||
var AsynchModelTypeCol = asynchModelTypeCol{
|
|
||||||
SQLBaseCol: beans.DefSQLBaseCol,
|
|
||||||
TypeID: "type_id",
|
|
||||||
TypeName: "type_name",
|
|
||||||
Remark: "remark",
|
|
||||||
}
|
|
||||||
|
|
||||||
// AsynchModelType 模型类型(图片/音频/视频等)
|
|
||||||
type AsynchModelType struct {
|
|
||||||
beans.SQLBaseDO `orm:",inline"`
|
|
||||||
TypeID int `orm:"type_id" json:"typeId"`
|
|
||||||
TypeName string `orm:"type_name" json:"type"`
|
|
||||||
Remark string `orm:"remark" json:"remark"`
|
|
||||||
}
|
|
||||||
|
|
||||||
@@ -2,8 +2,10 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
|
"model-gateway/model/dto"
|
||||||
|
|
||||||
"model-gateway/consts/public"
|
"model-gateway/consts/public"
|
||||||
"model-gateway/model/entity"
|
"model-gateway/model/entity"
|
||||||
@@ -34,9 +36,12 @@ type AutoTuneResult struct {
|
|||||||
// - 基于吞吐与 P90 执行耗时估算 max_concurrency 的运行时值(不超过 cap)
|
// - 基于吞吐与 P90 执行耗时估算 max_concurrency 的运行时值(不超过 cap)
|
||||||
// - queue_limit 与 expected_seconds 绑定(允许排队时间 = expected_seconds * 2),生成运行时值(不超过 cap)
|
// - queue_limit 与 expected_seconds 绑定(允许排队时间 = expected_seconds * 2),生成运行时值(不超过 cap)
|
||||||
// - 单次调整幅度限制 ±50%,写入 Redis(带 TTL)
|
// - 单次调整幅度限制 ±50%,写入 Redis(带 TTL)
|
||||||
func AutoTune(ctx context.Context, windowSeconds int) ([]AutoTuneResult, error) {
|
func AutoTune(ctx context.Context, req *dto.AutoTuneReq) (res *dto.AutoTuneRes, err error) {
|
||||||
if windowSeconds <= 0 {
|
if req == nil {
|
||||||
windowSeconds = 3600
|
return nil, errors.New("request cannot be nil")
|
||||||
|
}
|
||||||
|
if req.WindowSeconds <= 0 {
|
||||||
|
req.WindowSeconds = 3600 // 默认1小时
|
||||||
}
|
}
|
||||||
// 1) 读取模型配置(cap),按 model_name 聚合去重(如果表里有多租户重复数据,取较大上限)
|
// 1) 读取模型配置(cap),按 model_name 聚合去重(如果表里有多租户重复数据,取较大上限)
|
||||||
var modelRows []*entity.AsynchModel
|
var modelRows []*entity.AsynchModel
|
||||||
@@ -68,7 +73,7 @@ func AutoTune(ctx context.Context, windowSeconds int) ([]AutoTuneResult, error)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(modelMap) == 0 {
|
if len(modelMap) == 0 {
|
||||||
return []AutoTuneResult{}, nil
|
return nil, errors.New("no models found")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2) 统计指定窗口:按 model_name 计算 cnt 和 P90 执行耗时
|
// 2) 统计指定窗口:按 model_name 计算 cnt 和 P90 执行耗时
|
||||||
@@ -89,7 +94,7 @@ SELECT model_name,
|
|||||||
AND finished_at IS NOT NULL
|
AND finished_at IS NOT NULL
|
||||||
AND finished_at >= (NOW() - (? || ' seconds')::interval)
|
AND finished_at >= (NOW() - (? || ' seconds')::interval)
|
||||||
GROUP BY model_name`, public.TableNameTask)
|
GROUP BY model_name`, public.TableNameTask)
|
||||||
r, err := gfdb.DB(ctx).GetAll(ctx, sql, windowSeconds)
|
r, err := gfdb.DB(ctx).GetAll(ctx, sql, req.WindowSeconds)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -189,6 +194,8 @@ SELECT model_name,
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
g.Log().Infof(ctx, "[auto_tune] done models=%d windowSeconds=%d", len(out), windowSeconds)
|
g.Log().Infof(ctx, "[auto_tune] done models=%d windowSeconds=%d", len(out), req.WindowSeconds)
|
||||||
return out, nil
|
return &dto.AutoTuneRes{
|
||||||
|
List: out,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,67 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
|
|
||||||
"model-gateway/model/entity"
|
|
||||||
|
|
||||||
"gitea.com/red-future/common/http"
|
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
|
||||||
)
|
|
||||||
|
|
||||||
// triggerCallback 任务成功后的回调:
|
|
||||||
// - JSON body 参数:task_id/state/oss_file/file_type/text(可选)
|
|
||||||
func triggerCallback(ctx context.Context, t *entity.AsynchTask) {
|
|
||||||
callbackURL := t.BizName + t.CallbackURL
|
|
||||||
headers := forwardHeaders(ctx)
|
|
||||||
var req struct{}
|
|
||||||
payload := map[string]interface{}{
|
|
||||||
"task_id": t.TaskID,
|
|
||||||
"state": t.State,
|
|
||||||
"oss_file": t.OssFile,
|
|
||||||
"file_type": t.FileType,
|
|
||||||
"text": t.TextResult,
|
|
||||||
"error_msg": t.ErrorMsg,
|
|
||||||
}
|
|
||||||
jsonData, err := json.Marshal(payload)
|
|
||||||
if err != nil {
|
|
||||||
g.Log().Warningf(ctx, "[回调] JSON序列化失败 taskId=%s 错误=%v", t.TaskID, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
g.Log().Infof(ctx, "[回调] 开始发送 taskId=%s 回调地址=%s 请求头数量=%d 消息体大小=%d字节",
|
|
||||||
t.TaskID, callbackURL, len(headers), len(jsonData))
|
|
||||||
|
|
||||||
err = http.Post(ctx, callbackURL, headers, &req, jsonData)
|
|
||||||
if err != nil {
|
|
||||||
g.Log().Warningf(ctx, "[回调] 发送失败 taskId=%s 回调地址=%s 错误=%v", t.TaskID, callbackURL, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
g.Log().Infof(ctx, "[回调] 发送成功 taskId=%s 回调地址=%s 消息体大小=%d字节", t.TaskID, callbackURL, len(jsonData))
|
|
||||||
}
|
|
||||||
|
|
||||||
// triggerPromptsCallback 任务成功后的提示词回调
|
|
||||||
// - JSON body 参数:epicycleId(轮次id)/textResult(模型回答消息)
|
|
||||||
func triggerPromptsCallback(ctx context.Context, t *entity.AsynchTask, epicycleId int64) {
|
|
||||||
callbackURL := "prompts-core/session/sessionCallback"
|
|
||||||
headers := forwardHeaders(ctx)
|
|
||||||
var req struct{}
|
|
||||||
payload := map[string]interface{}{
|
|
||||||
"epicycleId": epicycleId,
|
|
||||||
"text": t.TextResult,
|
|
||||||
}
|
|
||||||
jsonData, err := json.Marshal(payload)
|
|
||||||
if err != nil {
|
|
||||||
g.Log().Warningf(ctx, "[提示词回调] JSON序列化失败 epicycleId=%d 错误=%v", epicycleId, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
g.Log().Infof(ctx, "[提示词回调] 开始发送 epicycleId=%d 回调地址=%s 请求头数量=%d 消息体大小=%d字节",
|
|
||||||
t.EpicycleId, callbackURL, len(headers), len(jsonData))
|
|
||||||
|
|
||||||
err = http.Post(ctx, callbackURL, headers, &req, jsonData)
|
|
||||||
if err != nil {
|
|
||||||
g.Log().Warningf(ctx, "[提示词回调] 发送失败 epicycleId=%d 回调地址=%s 错误=%v", t.EpicycleId, callbackURL, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
g.Log().Infof(ctx, "[提示词回调] 发送成功 epicycleId=%d 回调地址=%s 消息体大小=%d字节", t.EpicycleId, callbackURL, len(jsonData))
|
|
||||||
}
|
|
||||||
@@ -2,6 +2,8 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"model-gateway/model/dto"
|
||||||
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"model-gateway/dao"
|
"model-gateway/dao"
|
||||||
@@ -14,14 +16,14 @@ var Cleaner = &cleaner{}
|
|||||||
type cleaner struct{}
|
type cleaner struct{}
|
||||||
|
|
||||||
// RunOnce 由上层定时任务触发:执行一次清理/重试
|
// RunOnce 由上层定时任务触发:执行一次清理/重试
|
||||||
func (c *cleaner) RunOnce(ctx context.Context) {
|
func (c *cleaner) RunOnce(ctx context.Context) (res *dto.CleanWorkRes, err error) {
|
||||||
// 1) 清理已下载(state=4)且过期的任务(硬删除 + OSS)
|
// 1) 清理已下载(state=4)且过期的任务(硬删除 + OSS)
|
||||||
expired, err := dao.Task.ListExpiredDownloadedGlobal(ctx, 200)
|
expired, err := dao.Task.ListExpiredDownloadedGlobal(ctx, 200)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
g.Log().Errorf(ctx, "[cleaner] list expired(downloaded) error: %v", err)
|
g.Log().Errorf(ctx, "[cleaner] list expired(downloaded) error: %v", err)
|
||||||
} else {
|
} else {
|
||||||
for _, t := range expired {
|
for _, t := range expired {
|
||||||
deleteTmpResult(t.TmpFile)
|
_ = os.Remove(t.TmpFile)
|
||||||
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
|
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
|
||||||
}
|
}
|
||||||
g.Log().Infof(ctx, "[cleaner] expired(downloaded) cleaned, count=%d", len(expired))
|
g.Log().Infof(ctx, "[cleaner] expired(downloaded) cleaned, count=%d", len(expired))
|
||||||
@@ -82,11 +84,14 @@ func (c *cleaner) RunOnce(ctx context.Context) {
|
|||||||
g.Log().Errorf(ctx, "[cleaner] list failed exhausted error: %v", err)
|
g.Log().Errorf(ctx, "[cleaner] list failed exhausted error: %v", err)
|
||||||
} else {
|
} else {
|
||||||
for _, t := range exhausted {
|
for _, t := range exhausted {
|
||||||
deleteTmpResult(t.TmpFile)
|
_ = os.Remove(t.TmpFile)
|
||||||
// 重试耗尽硬删除:释放闸门占位(兜底,若此前已释放则幂等)
|
// 重试耗尽硬删除:释放闸门占位(兜底,若此前已释放则幂等)
|
||||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||||
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
|
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
|
||||||
}
|
}
|
||||||
g.Log().Infof(ctx, "[cleaner] failed exhausted cleaned, count=%d", len(exhausted))
|
g.Log().Infof(ctx, "[cleaner] failed exhausted cleaned, count=%d", len(exhausted))
|
||||||
}
|
}
|
||||||
|
return &dto.CleanWorkRes{
|
||||||
|
Ok: true,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,47 +1 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
// DetectFileType 根据返回的二进制内容推断 contentType + 扩展名(尽量稳定)
|
|
||||||
func DetectFileType(data []byte) (contentType string, ext string) {
|
|
||||||
if len(data) == 0 {
|
|
||||||
return "application/octet-stream", ""
|
|
||||||
}
|
|
||||||
ct := http.DetectContentType(data)
|
|
||||||
// gateway.DetectContentType 可能带 charset 等参数:text/plain; charset=utf-8
|
|
||||||
if idx := strings.Index(ct, ";"); idx > 0 {
|
|
||||||
ct = strings.TrimSpace(ct[:idx])
|
|
||||||
}
|
|
||||||
switch ct {
|
|
||||||
case "audio/mpeg":
|
|
||||||
return ct, ".mp3"
|
|
||||||
case "audio/wave", "audio/wav", "audio/x-wav":
|
|
||||||
return ct, ".wav"
|
|
||||||
case "video/mp4":
|
|
||||||
return ct, ".mp4"
|
|
||||||
case "image/png":
|
|
||||||
return ct, ".png"
|
|
||||||
case "image/jpeg":
|
|
||||||
return ct, ".jpg"
|
|
||||||
case "application/pdf":
|
|
||||||
return ct, ".pdf"
|
|
||||||
case "text/plain":
|
|
||||||
return ct, ".txt"
|
|
||||||
case "application/json":
|
|
||||||
return ct, ".json"
|
|
||||||
default:
|
|
||||||
// 兜底:尝试从 ct 截取 subtype 作为后缀(例如 application/json)
|
|
||||||
if parts := strings.Split(ct, "/"); len(parts) == 2 {
|
|
||||||
sub := parts[1]
|
|
||||||
// 避免出现 "plain; charset=utf-8" 之类的后缀
|
|
||||||
if idx := strings.Index(sub, ";"); idx > 0 {
|
|
||||||
sub = strings.TrimSpace(sub[:idx])
|
|
||||||
}
|
|
||||||
return ct, "." + sub
|
|
||||||
}
|
|
||||||
return ct, ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
171
service/gateway/gateway_http_service.go
Normal file
171
service/gateway/gateway_http_service.go
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
package gateway
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"mime/multipart"
|
||||||
|
"model-gateway/common/util"
|
||||||
|
"model-gateway/model/entity"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
commonHttp "gitea.com/red-future/common/http"
|
||||||
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
|
"github.com/gogf/gf/v2/util/guid"
|
||||||
|
)
|
||||||
|
|
||||||
|
type uploadFileResponse struct {
|
||||||
|
FileURL string `json:"fileURL"` // 文件 URL
|
||||||
|
FileSize int `json:"fileSize"` // 文件大小(字节)
|
||||||
|
FileName string `json:"fileName"` // 文件名
|
||||||
|
FileFormat string `json:"fileFormat"` // 文件格式
|
||||||
|
FileAddressPrefix string `json:"fileAddressPrefix"` // 文件地址前缀
|
||||||
|
}
|
||||||
|
|
||||||
|
func UploadByTask(ctx context.Context, _ *entity.AsynchTask, data []byte, fileExt string, _ string) (ossURL string, err error) {
|
||||||
|
// multipart
|
||||||
|
body := &bytes.Buffer{}
|
||||||
|
writer := multipart.NewWriter(body)
|
||||||
|
|
||||||
|
ext := fileExt
|
||||||
|
if ext == "" {
|
||||||
|
ext = ".bin"
|
||||||
|
}
|
||||||
|
if ext[0] != '.' {
|
||||||
|
ext = "." + ext
|
||||||
|
}
|
||||||
|
|
||||||
|
filename := fmt.Sprintf("asynch_%d_%s%s", time.Now().Unix(), guid.S(), ext)
|
||||||
|
part, err := writer.CreateFormFile("file", filename)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if _, err = part.Write(data); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
headers := util.ForwardHeaders(ctx)
|
||||||
|
fullURL := "oss/file/uploadFile"
|
||||||
|
g.Log().Infof(ctx, "[OSS] upload start url=%s filename=%s size=%d", fullURL, filename, len(data))
|
||||||
|
|
||||||
|
var resp uploadFileResponse
|
||||||
|
if err = commonHttp.Post(ctx, fullURL, headers, &resp, body.Bytes()); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
g.Log().Infof(ctx, "[OSS] upload success url=%s size=%d format=%s", resp.FileURL, resp.FileSize, resp.FileFormat)
|
||||||
|
return resp.FileURL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TriggerCallback 任务成功后的回调:
|
||||||
|
// - JSON body 参数:task_id/state/oss_file/file_type/text(可选)
|
||||||
|
func TriggerCallback(ctx context.Context, t *entity.AsynchTask) {
|
||||||
|
headers := util.ForwardHeaders(ctx)
|
||||||
|
var req struct{}
|
||||||
|
payload := map[string]interface{}{
|
||||||
|
"task_id": t.TaskID,
|
||||||
|
"state": t.State,
|
||||||
|
"oss_file": t.OssFile,
|
||||||
|
"file_type": t.FileType,
|
||||||
|
"text": t.TextResult,
|
||||||
|
"error_msg": t.ErrorMsg,
|
||||||
|
}
|
||||||
|
jsonData, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
g.Log().Warningf(ctx, "[回调] JSON序列化失败 taskId=%s 错误=%v", t.TaskID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
g.Log().Infof(ctx, "[回调] 开始发送 taskId=%s 回调地址=%s 请求头数量=%d 消息体大小=%d字节",
|
||||||
|
t.TaskID, t.CallbackURL, len(headers), len(jsonData))
|
||||||
|
|
||||||
|
err = commonHttp.Post(ctx, t.CallbackURL, headers, &req, jsonData)
|
||||||
|
if err != nil {
|
||||||
|
g.Log().Warningf(ctx, "[回调] 发送失败 taskId=%s 回调地址=%s 错误=%v", t.TaskID, t.CallbackURL, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
g.Log().Infof(ctx, "[回调] 发送成功 taskId=%s 回调地址=%s 消息体大小=%d字节", t.TaskID, t.CallbackURL, len(jsonData))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TriggerPromptsCallback 任务成功后的提示词回调
|
||||||
|
// - JSON body 参数:epicycleId(轮次id)/textResult(模型回答消息)
|
||||||
|
func TriggerPromptsCallback(ctx context.Context, t *entity.AsynchTask, epicycleId int64) {
|
||||||
|
callbackURL := "prompts-core/session/sessionCallback"
|
||||||
|
headers := util.ForwardHeaders(ctx)
|
||||||
|
var req struct{}
|
||||||
|
payload := map[string]interface{}{
|
||||||
|
"epicycleId": epicycleId,
|
||||||
|
"text": t.TextResult,
|
||||||
|
}
|
||||||
|
jsonData, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
g.Log().Warningf(ctx, "[提示词回调] JSON序列化失败 epicycleId=%d 错误=%v", epicycleId, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
g.Log().Infof(ctx, "[提示词回调] 开始发送 epicycleId=%d 回调地址=%s 请求头数量=%d 消息体大小=%d字节",
|
||||||
|
t.EpicycleId, callbackURL, len(headers), len(jsonData))
|
||||||
|
|
||||||
|
err = commonHttp.Post(ctx, callbackURL, headers, &req, jsonData)
|
||||||
|
if err != nil {
|
||||||
|
g.Log().Warningf(ctx, "[提示词回调] 发送失败 epicycleId=%d 回调地址=%s 错误=%v", t.EpicycleId, callbackURL, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
g.Log().Infof(ctx, "[提示词回调] 发送成功 epicycleId=%d 回调地址=%s 消息体大小=%d字节", t.EpicycleId, callbackURL, len(jsonData))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsSuperAdmin 调用admin-go服务检查是否是超级管理员
|
||||||
|
func IsSuperAdmin(ctx context.Context) (res bool, err error) {
|
||||||
|
headers := util.ForwardHeaders(ctx)
|
||||||
|
var r = make(map[string]bool)
|
||||||
|
if err = commonHttp.Get(ctx, "admin-go/api/v1/system/user/checkIsSuperAdmin", headers, &r); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return r["isSuperAdmin"], err
|
||||||
|
}
|
||||||
|
|
||||||
|
//// callback 向回调地址 POST 任务结果(与查询接口 GetTaskRes 出参一致)
|
||||||
|
//func (s *audioTaskService) callback(ctx context.Context, taskID, status, errMsg, callbackURL string) {
|
||||||
|
// if callbackURL == "" {
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// task, _ := dao.TranscribeTask.GetByTaskID(ctx, taskID)
|
||||||
|
// if task == nil {
|
||||||
|
// g.Log().Errorf(ctx, "[回调 %s] 任务不存在", taskID)
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// detailList, _ := dao.TranscribeTaskDetail.ListByTaskID(ctx, taskID)
|
||||||
|
// detailItems := make([]dto.TranscribeTaskDetailItem, 0, len(detailList))
|
||||||
|
// for i := range detailList {
|
||||||
|
// detailItems = append(detailItems, dao.DetailEntityToItem(&detailList[i]))
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// // 构建与查询接口一致的 taskInfo
|
||||||
|
// taskInfo := dao.EntityToItem(task)
|
||||||
|
//
|
||||||
|
// // 兼容历史数据: 从 result 中补全 scenes 等字段
|
||||||
|
// detailItems = enrichDetailsFromResult(task.Result, detailItems)
|
||||||
|
//
|
||||||
|
// payload := dto.CallbackPayload{
|
||||||
|
// TaskInfo: taskInfo,
|
||||||
|
// DetailList: detailItems,
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// body, _ := json.Marshal(payload)
|
||||||
|
//
|
||||||
|
// // 透传调用方的用户信息
|
||||||
|
// userJSON, _ := json.Marshal(beans.User{UserName: "admin", TenantId: 1})
|
||||||
|
//
|
||||||
|
// req, _ := http.NewRequest("POST", callbackURL, bytes.NewReader(body))
|
||||||
|
// req.Header.Set("Content-Type", "application/json")
|
||||||
|
// req.Header.Set("X-User-Info", string(userJSON))
|
||||||
|
//
|
||||||
|
// resp, reqErr := http.DefaultClient.Do(req)
|
||||||
|
// if reqErr != nil {
|
||||||
|
// g.Log().Errorf(ctx, "[回调 %s] 请求失败: %v", taskID, reqErr)
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
// defer resp.Body.Close()
|
||||||
|
//
|
||||||
|
// respBody, _ := io.ReadAll(resp.Body)
|
||||||
|
// g.Log().Infof(ctx, "[回调 %s] 响应 status=%d, body=%s", taskID, resp.StatusCode, string(respBody))
|
||||||
|
//}
|
||||||
@@ -1,53 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"gitea.com/red-future/common/utils"
|
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
|
||||||
)
|
|
||||||
|
|
||||||
// asyncCtx 固化异步执行所需的 token/user,避免请求结束后丢失(仅在“同请求内起 goroutine”有用)。
|
|
||||||
// 本项目当前是“落库 + 后台 worker”模式,因此还会把必要信息持久化到任务表的 request_payload 中。
|
|
||||||
func asyncCtx(ctx context.Context) context.Context {
|
|
||||||
asyncCtx := context.WithoutCancel(ctx)
|
|
||||||
if r := g.RequestFromCtx(ctx); r != nil {
|
|
||||||
if token := r.Header.Get("Authorization"); token != "" {
|
|
||||||
asyncCtx = context.WithValue(asyncCtx, "token", token)
|
|
||||||
}
|
|
||||||
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
|
|
||||||
asyncCtx = context.WithValue(asyncCtx, "xUserInfo", userInfo)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if user, err := utils.GetUserInfo(ctx); err == nil && user != nil {
|
|
||||||
asyncCtx = context.WithValue(asyncCtx, "user", user)
|
|
||||||
}
|
|
||||||
return asyncCtx
|
|
||||||
}
|
|
||||||
|
|
||||||
// forwardHeaders 透传调用链路中必须的头信息(优先使用 ctx 里固化的 token / xUserInfo)。
|
|
||||||
func forwardHeaders(ctx context.Context) map[string]string {
|
|
||||||
headers := make(map[string]string)
|
|
||||||
|
|
||||||
if token, ok := ctx.Value("token").(string); ok && token != "" {
|
|
||||||
headers["Authorization"] = token
|
|
||||||
}
|
|
||||||
if x, ok := ctx.Value("xUserInfo").(string); ok && x != "" {
|
|
||||||
headers["X-User-Info"] = x
|
|
||||||
}
|
|
||||||
|
|
||||||
// 兜底:从请求头拿
|
|
||||||
if r := g.RequestFromCtx(ctx); r != nil {
|
|
||||||
if headers["Authorization"] == "" {
|
|
||||||
if token := r.Header.Get("Authorization"); token != "" {
|
|
||||||
headers["Authorization"] = token
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if headers["X-User-Info"] == "" {
|
|
||||||
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
|
|
||||||
headers["X-User-Info"] = userInfo
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return headers
|
|
||||||
}
|
|
||||||
@@ -3,13 +3,15 @@ package service
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"model-gateway/common/util"
|
||||||
|
"model-gateway/consts/public"
|
||||||
"model-gateway/dao"
|
"model-gateway/dao"
|
||||||
"model-gateway/model/dto"
|
"model-gateway/model/dto"
|
||||||
"model-gateway/model/entity"
|
"model-gateway/model/entity"
|
||||||
|
"model-gateway/service/gateway"
|
||||||
|
|
||||||
"gitea.com/red-future/common/beans"
|
"gitea.com/red-future/common/beans"
|
||||||
"gitea.com/red-future/common/db/gfdb"
|
"gitea.com/red-future/common/db/gfdb"
|
||||||
"gitea.com/red-future/common/http"
|
|
||||||
"gitea.com/red-future/common/utils"
|
"gitea.com/red-future/common/utils"
|
||||||
"github.com/gogf/gf/v2/database/gdb"
|
"github.com/gogf/gf/v2/database/gdb"
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
@@ -20,28 +22,20 @@ var Model = &modelService{}
|
|||||||
|
|
||||||
type modelService struct{}
|
type modelService struct{}
|
||||||
|
|
||||||
// IsSuperAdmin 调用admin-go服务检查是否是超级管理员
|
|
||||||
func (s *modelService) IsSuperAdmin(ctx context.Context) (res bool, err error) {
|
|
||||||
headers := forwardHeaders(ctx)
|
|
||||||
var r = make(map[string]bool)
|
|
||||||
if err = http.Get(ctx, "admin-go/api/v1/system/user/checkIsSuperAdmin", headers, &r); err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
return r["isSuperAdmin"], err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
|
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
|
||||||
// 获取当前会话模型
|
// 获取当前会话模型
|
||||||
if !g.IsEmpty(req.IsChatModel) && *req.IsChatModel == 1 {
|
if !g.IsEmpty(req.IsChatModel) && *req.IsChatModel == 1 {
|
||||||
var model *entity.AsynchModel
|
var model *entity.AsynchModel
|
||||||
model, err = dao.Model.GetByIsChatModel(ctx)
|
model, err = dao.Model.Get(ctx, &entity.AsynchModel{
|
||||||
|
IsChatModel: new(1),
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// 如果有会话模型,那就改变为 0
|
// 如果有会话模型,那就改变为 0
|
||||||
if model != nil {
|
if model != nil {
|
||||||
_, err = dao.Model.Update(ctx, &dto.UpdateModelReq{
|
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
||||||
ID: model.Id,
|
SQLBaseDO: beans.SQLBaseDO{Id: model.Id},
|
||||||
IsChatModel: gconv.PtrInt(0),
|
IsChatModel: gconv.PtrInt(0),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -51,14 +45,40 @@ func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res
|
|||||||
}
|
}
|
||||||
|
|
||||||
req.IsOwner = gconv.PtrInt(1)
|
req.IsOwner = gconv.PtrInt(1)
|
||||||
admin, err := s.IsSuperAdmin(ctx)
|
admin, err := gateway.IsSuperAdmin(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if admin {
|
if admin {
|
||||||
req.IsOwner = gconv.PtrInt(0)
|
req.IsOwner = gconv.PtrInt(0)
|
||||||
}
|
}
|
||||||
id, err := dao.Model.Insert(ctx, req)
|
id, err := dao.Model.Insert(ctx, &entity.AsynchModel{
|
||||||
|
ModelName: req.ModelName,
|
||||||
|
ModelType: req.ModelType,
|
||||||
|
BaseURL: req.BaseURL,
|
||||||
|
HttpMethod: req.HttpMethod,
|
||||||
|
HeadMsg: req.HeadMsg,
|
||||||
|
Form: req.Form,
|
||||||
|
RequestMapping: req.RequestMapping,
|
||||||
|
ResponseMapping: req.ResponseMapping,
|
||||||
|
ResponseBody: req.ResponseBody,
|
||||||
|
ResponseTokenField: req.ResponseTokenField,
|
||||||
|
IsPrivate: req.IsPrivate,
|
||||||
|
IsChatModel: req.IsChatModel,
|
||||||
|
ApiKey: req.ApiKey,
|
||||||
|
Enabled: req.Enabled,
|
||||||
|
MaxConcurrency: req.MaxConcurrency,
|
||||||
|
QueueLimit: req.QueueLimit,
|
||||||
|
TimeoutSeconds: req.TimeoutSeconds,
|
||||||
|
ExpectedSeconds: req.ExpectedSeconds,
|
||||||
|
RetryTimes: req.RetryTimes,
|
||||||
|
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
|
||||||
|
AutoCleanSeconds: req.AutoCleanSeconds,
|
||||||
|
Remark: req.Remark,
|
||||||
|
IsOwner: req.IsOwner,
|
||||||
|
OperatorName: req.OperatorName,
|
||||||
|
TokenConfig: req.TokenConfig,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -69,7 +89,9 @@ func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) erro
|
|||||||
//根据当前 isChatModel 来判断是否更新模型
|
//根据当前 isChatModel 来判断是否更新模型
|
||||||
if req.IsChatModel == gconv.PtrInt(1) {
|
if req.IsChatModel == gconv.PtrInt(1) {
|
||||||
//判断当前用户是否有会话模型
|
//判断当前用户是否有会话模型
|
||||||
model, err := dao.Model.GetByIsChatModel(ctx)
|
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||||
|
IsChatModel: new(1),
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -79,68 +101,146 @@ func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
req.IsOwner = gconv.PtrInt(1)
|
req.IsOwner = gconv.PtrInt(1)
|
||||||
admin, err := s.IsSuperAdmin(ctx)
|
admin, err := gateway.IsSuperAdmin(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if admin {
|
if admin {
|
||||||
req.IsOwner = gconv.PtrInt(0)
|
req.IsOwner = gconv.PtrInt(0)
|
||||||
_, err = dao.Model.Update(ctx, req)
|
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
||||||
|
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
|
||||||
|
ModelName: req.ModelName,
|
||||||
|
ModelType: req.ModelType,
|
||||||
|
BaseURL: req.BaseURL,
|
||||||
|
HttpMethod: req.HttpMethod,
|
||||||
|
HeadMsg: req.HeadMsg,
|
||||||
|
Form: req.Form,
|
||||||
|
RequestMapping: req.RequestMapping,
|
||||||
|
ResponseMapping: req.ResponseMapping,
|
||||||
|
ResponseBody: req.ResponseBody,
|
||||||
|
ResponseTokenField: req.ResponseTokenField,
|
||||||
|
IsPrivate: req.IsPrivate,
|
||||||
|
IsChatModel: req.IsChatModel,
|
||||||
|
ApiKey: req.ApiKey,
|
||||||
|
Enabled: req.Enabled,
|
||||||
|
MaxConcurrency: req.MaxConcurrency,
|
||||||
|
QueueLimit: req.QueueLimit,
|
||||||
|
TimeoutSeconds: req.TimeoutSeconds,
|
||||||
|
ExpectedSeconds: req.ExpectedSeconds,
|
||||||
|
RetryTimes: req.RetryTimes,
|
||||||
|
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
|
||||||
|
AutoCleanSeconds: req.AutoCleanSeconds,
|
||||||
|
Remark: req.Remark,
|
||||||
|
IsOwner: req.IsOwner,
|
||||||
|
OperatorName: req.OperatorName,
|
||||||
|
TokenConfig: req.TokenConfig,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var user *beans.User
|
|
||||||
user, err = utils.GetUserInfo(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// 判断当前传过来的模型id的模型是否是超级管理员的。如果是超管的进行创建,否则更新
|
// 判断当前传过来的模型id的模型是否是超级管理员的。如果是超管的进行创建,否则更新
|
||||||
var count int
|
model, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
|
||||||
count, err = dao.Model.Count(ctx, &dto.GetModelReq{
|
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
|
||||||
ID: req.ID,
|
|
||||||
Creator: user.UserName,
|
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if count == 0 {
|
if model.TenantId == 1 {
|
||||||
insertDto := new(dto.CreateModelReq)
|
insertDto := new(dto.CreateModelReq)
|
||||||
err = gconv.Struct(req, insertDto)
|
err = gconv.Struct(req, insertDto)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = dao.Model.Insert(ctx, insertDto)
|
_, err = dao.Model.Insert(ctx, &entity.AsynchModel{
|
||||||
|
ModelName: req.ModelName,
|
||||||
|
ModelType: req.ModelType,
|
||||||
|
BaseURL: req.BaseURL,
|
||||||
|
HttpMethod: req.HttpMethod,
|
||||||
|
HeadMsg: req.HeadMsg,
|
||||||
|
Form: req.Form,
|
||||||
|
RequestMapping: req.RequestMapping,
|
||||||
|
ResponseMapping: req.ResponseMapping,
|
||||||
|
ResponseBody: req.ResponseBody,
|
||||||
|
ResponseTokenField: req.ResponseTokenField,
|
||||||
|
IsPrivate: req.IsPrivate,
|
||||||
|
IsChatModel: req.IsChatModel,
|
||||||
|
ApiKey: req.ApiKey,
|
||||||
|
Enabled: req.Enabled,
|
||||||
|
MaxConcurrency: req.MaxConcurrency,
|
||||||
|
QueueLimit: req.QueueLimit,
|
||||||
|
TimeoutSeconds: req.TimeoutSeconds,
|
||||||
|
ExpectedSeconds: req.ExpectedSeconds,
|
||||||
|
RetryTimes: req.RetryTimes,
|
||||||
|
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
|
||||||
|
AutoCleanSeconds: req.AutoCleanSeconds,
|
||||||
|
Remark: req.Remark,
|
||||||
|
IsOwner: req.IsOwner,
|
||||||
|
OperatorName: req.OperatorName,
|
||||||
|
TokenConfig: req.TokenConfig,
|
||||||
|
})
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = dao.Model.Update(ctx, req)
|
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
||||||
|
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
|
||||||
|
ModelName: req.ModelName,
|
||||||
|
ModelType: req.ModelType,
|
||||||
|
BaseURL: req.BaseURL,
|
||||||
|
HttpMethod: req.HttpMethod,
|
||||||
|
HeadMsg: req.HeadMsg,
|
||||||
|
Form: req.Form,
|
||||||
|
RequestMapping: req.RequestMapping,
|
||||||
|
ResponseMapping: req.ResponseMapping,
|
||||||
|
ResponseBody: req.ResponseBody,
|
||||||
|
ResponseTokenField: req.ResponseTokenField,
|
||||||
|
IsPrivate: req.IsPrivate,
|
||||||
|
IsChatModel: req.IsChatModel,
|
||||||
|
ApiKey: req.ApiKey,
|
||||||
|
Enabled: req.Enabled,
|
||||||
|
MaxConcurrency: req.MaxConcurrency,
|
||||||
|
QueueLimit: req.QueueLimit,
|
||||||
|
TimeoutSeconds: req.TimeoutSeconds,
|
||||||
|
ExpectedSeconds: req.ExpectedSeconds,
|
||||||
|
RetryTimes: req.RetryTimes,
|
||||||
|
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
|
||||||
|
AutoCleanSeconds: req.AutoCleanSeconds,
|
||||||
|
Remark: req.Remark,
|
||||||
|
IsOwner: req.IsOwner,
|
||||||
|
OperatorName: req.OperatorName,
|
||||||
|
TokenConfig: req.TokenConfig,
|
||||||
|
})
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *modelService) Delete(ctx context.Context, id string) error {
|
func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) error {
|
||||||
_, err := dao.Model.DeleteByID(ctx, id)
|
_, err := dao.Model.Delete(ctx, &entity.AsynchModel{
|
||||||
|
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
|
||||||
|
})
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *modelService) Get(ctx context.Context, id int64) (*entity.AsynchModel, error) {
|
func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetModelRes, error) {
|
||||||
model, err := dao.Model.Get(ctx, id)
|
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||||
|
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
model.Form = ParseJSONField(model.Form)
|
model.Form = util.ParseJSONField(model.Form)
|
||||||
model.RequestMapping = ParseJSONField(model.RequestMapping)
|
model.RequestMapping = util.ParseJSONField(model.RequestMapping)
|
||||||
model.ResponseMapping = ParseJSONField(model.ResponseMapping)
|
model.ResponseMapping = util.ParseJSONField(model.ResponseMapping)
|
||||||
model.ResponseBody = ParseJSONField(model.ResponseBody)
|
model.ResponseBody = util.ParseJSONField(model.ResponseBody)
|
||||||
return model, nil
|
return &dto.GetModelRes{
|
||||||
|
Model: model,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (list []*entity.AsynchModel, total int, err error) {
|
func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) {
|
||||||
var models []*entity.AsynchModel
|
var models []*entity.AsynchModel
|
||||||
|
|
||||||
req.IsOwner = gconv.PtrInt(1)
|
req.IsOwner = gconv.PtrInt(1)
|
||||||
admin, err := s.IsSuperAdmin(ctx)
|
admin, err := gateway.IsSuperAdmin(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -151,63 +251,55 @@ func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (list []
|
|||||||
var user *beans.User
|
var user *beans.User
|
||||||
user, err = utils.GetUserInfo(ctx)
|
user, err = utils.GetUserInfo(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, err
|
||||||
}
|
}
|
||||||
req.Creator = user.UserName
|
req.Creator = user.UserName
|
||||||
|
|
||||||
models, total, err = dao.Model.GetByCreatorAndPlatform(ctx, req)
|
models, total, err := dao.Model.GetByCreatorAndPlatform(ctx, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 处理列表中每条记录的 JSONB 字段
|
// 处理列表中每条记录的 JSONB 字段
|
||||||
for _, m := range models {
|
for _, m := range models {
|
||||||
m.Form = ParseJSONField(m.Form)
|
m.Form = util.ParseJSONField(m.Form)
|
||||||
m.RequestMapping = ParseJSONField(m.RequestMapping)
|
m.RequestMapping = util.ParseJSONField(m.RequestMapping)
|
||||||
m.ResponseMapping = ParseJSONField(m.ResponseMapping)
|
m.ResponseMapping = util.ParseJSONField(m.ResponseMapping)
|
||||||
m.ResponseBody = ParseJSONField(m.ResponseBody)
|
m.ResponseBody = util.ParseJSONField(m.ResponseBody)
|
||||||
}
|
}
|
||||||
return models, total, nil
|
return &dto.ListModelRes{
|
||||||
|
List: models,
|
||||||
|
Total: total,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetModelTypesFromConfig 从配置文件读取模型类型
|
// GetModelTypesFromConfig 从配置文件读取模型类型
|
||||||
func GetModelTypesFromConfig(ctx context.Context) map[int]string {
|
func GetModelTypesFromConfig() (res *dto.TypeItem, err error) {
|
||||||
typeMap := make(map[int]string)
|
// 返回副本,避免外部修改
|
||||||
|
types := make(map[int]string, len(public.ModelTypeName))
|
||||||
// 读取配置
|
for k, v := range public.ModelTypeName {
|
||||||
configMap := g.Cfg().MustGet(ctx, "modelType.types").Map()
|
types[k] = v
|
||||||
for k, v := range configMap {
|
|
||||||
typeID := gconv.Int(k)
|
|
||||||
typeName := gconv.String(v)
|
|
||||||
if typeID > 0 && typeName != "" {
|
|
||||||
typeMap[typeID] = typeName
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
// 如果配置为空,使用默认值
|
return &dto.TypeItem{
|
||||||
if len(typeMap) == 0 {
|
Type: types,
|
||||||
typeMap = map[int]string{
|
}, nil
|
||||||
1: "推理模型",
|
|
||||||
2: "图片模型",
|
|
||||||
3: "音频模型",
|
|
||||||
4: "向量化模型",
|
|
||||||
5: "全模态模型",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return typeMap
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error {
|
func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error {
|
||||||
// 校验新会话模型是否存在
|
// 校验新会话模型是否存在
|
||||||
newModel, err := dao.Model.Get(ctx, req.Id)
|
newModel, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
|
||||||
|
SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if newModel == nil {
|
if newModel == nil {
|
||||||
return errors.New("新会话模型不存在")
|
return errors.New("新会话模型不存在")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取当前用户会话模型
|
// 获取当前用户会话模型
|
||||||
currentModel, err := dao.Model.GetByIsChatModel(ctx)
|
currentModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||||
|
IsChatModel: new(1),
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -219,8 +311,8 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM
|
|||||||
|
|
||||||
// 如果点击的就是当前会话模型(已经是1),取消它(设为0)
|
// 如果点击的就是当前会话模型(已经是1),取消它(设为0)
|
||||||
if currentModel.Id != req.Id {
|
if currentModel.Id != req.Id {
|
||||||
_, err = dao.Model.Update(ctx, &dto.UpdateModelReq{
|
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
||||||
ID: currentModel.Id,
|
SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id},
|
||||||
IsChatModel: gconv.PtrInt(0),
|
IsChatModel: gconv.PtrInt(0),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -230,8 +322,8 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 设置当前为会话模型(设为1)
|
// 设置当前为会话模型(设为1)
|
||||||
_, err = dao.Model.Update(ctx, &dto.UpdateModelReq{
|
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
||||||
ID: req.Id,
|
SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id},
|
||||||
IsChatModel: gconv.PtrInt(1),
|
IsChatModel: gconv.PtrInt(1),
|
||||||
})
|
})
|
||||||
return err
|
return err
|
||||||
@@ -239,17 +331,21 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *modelService) GetIsChatModel(ctx context.Context) (*entity.AsynchModel, error) {
|
func (s *modelService) GetIsChatModel(ctx context.Context) (*dto.GetIsChatModelRes, error) {
|
||||||
model, err := dao.Model.GetByIsChatModel(ctx)
|
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||||
|
IsChatModel: new(1),
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if model == nil {
|
if model == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
model.Form = ParseJSONField(model.Form)
|
model.Form = util.ParseJSONField(model.Form)
|
||||||
model.RequestMapping = ParseJSONField(model.RequestMapping)
|
model.RequestMapping = util.ParseJSONField(model.RequestMapping)
|
||||||
model.ResponseMapping = ParseJSONField(model.ResponseMapping)
|
model.ResponseMapping = util.ParseJSONField(model.ResponseMapping)
|
||||||
model.ResponseBody = ParseJSONField(model.ResponseBody)
|
model.ResponseBody = util.ParseJSONField(model.ResponseBody)
|
||||||
return model, nil
|
return &dto.GetIsChatModelRes{
|
||||||
|
Model: model,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,25 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import "github.com/gogf/gf/v2/util/gconv"
|
|
||||||
|
|
||||||
// parseStoredPayload 解析入库的 request_payload,拆出模型调用 payload 与透传 headers
|
|
||||||
// 入库格式:{"payload": <any>, "headers": {"Authorization": "...", "X-User-Info":"..."}}
|
|
||||||
func parseStoredPayload(v any) (payload any, headers map[string]string) {
|
|
||||||
if v == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
m := gconv.Map(v)
|
|
||||||
if len(m) == 0 {
|
|
||||||
return v, nil
|
|
||||||
}
|
|
||||||
if h, ok := m["headers"]; ok {
|
|
||||||
headers = gconv.MapStrStr(h)
|
|
||||||
}
|
|
||||||
if p, ok := m["payload"]; ok {
|
|
||||||
payload = p
|
|
||||||
} else {
|
|
||||||
payload = v
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
|
|
||||||
"model-gateway/model/entity"
|
|
||||||
)
|
|
||||||
|
|
||||||
// StorageService 结果存储(OSS/MinIO)抽象
|
|
||||||
type StorageService interface {
|
|
||||||
UploadByTask(ctx context.Context, t *entity.AsynchTask, data []byte, fileExt string, contentType string) (ossURL string, err error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Storage 默认存储实现(优先对接你们的 oss 文件服务;必要时也可以切到 MinIO)
|
|
||||||
var Storage StorageService = &ossStorage{}
|
|
||||||
|
|
||||||
var ErrStorageNotConfigured = errors.New("存储未配置")
|
|
||||||
@@ -1,81 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"mime/multipart"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"model-gateway/model/entity"
|
|
||||||
|
|
||||||
commonHttp "gitea.com/red-future/common/http"
|
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
|
||||||
"github.com/gogf/gf/v2/util/gconv"
|
|
||||||
"github.com/gogf/gf/v2/util/guid"
|
|
||||||
)
|
|
||||||
|
|
||||||
// 对接你们的 oss 文件服务:POST oss/file/uploadFile (multipart/form-data)
|
|
||||||
type ossStorage struct{}
|
|
||||||
|
|
||||||
type uploadFileResponse struct {
|
|
||||||
FileURL string `json:"fileURL"` // 文件 URL
|
|
||||||
FileSize int `json:"fileSize"` // 文件大小(字节)
|
|
||||||
FileName string `json:"fileName"` // 文件名
|
|
||||||
FileFormat string `json:"fileFormat"` // 文件格式
|
|
||||||
FileAddressPrefix string `json:"fileAddressPrefix"` // 文件地址前缀
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *ossStorage) UploadByTask(ctx context.Context, _ *entity.AsynchTask, data []byte, fileExt string, _ string) (ossURL string, err error) {
|
|
||||||
// multipart
|
|
||||||
body := &bytes.Buffer{}
|
|
||||||
writer := multipart.NewWriter(body)
|
|
||||||
|
|
||||||
ext := fileExt
|
|
||||||
if ext == "" {
|
|
||||||
ext = ".bin"
|
|
||||||
}
|
|
||||||
if ext[0] != '.' {
|
|
||||||
ext = "." + ext
|
|
||||||
}
|
|
||||||
|
|
||||||
filename := fmt.Sprintf("asynch_%d_%s%s", time.Now().Unix(), guid.S(), ext)
|
|
||||||
part, err := writer.CreateFormFile("file", filename)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
if _, err := part.Write(data); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
contentType := writer.FormDataContentType()
|
|
||||||
if err := writer.Close(); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
headers := forwardHeaders(ctx)
|
|
||||||
headers["Content-Type"] = contentType
|
|
||||||
|
|
||||||
fullURL := "oss/file/uploadFile"
|
|
||||||
g.Log().Infof(ctx, "[OSS] upload start url=%s filename=%s size=%d", fullURL, filename, len(data))
|
|
||||||
|
|
||||||
var resp uploadFileResponse
|
|
||||||
if err := commonHttp.Post(ctx, fullURL, headers, &resp, body.Bytes()); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
g.Log().Infof(ctx, "[OSS] upload success url=%s size=%d format=%s", resp.FileURL, resp.FileSize, resp.FileFormat)
|
|
||||||
return resp.FileURL, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// setTaskHeadersToCtx 把任务入库时保存的 header 信息注入 ctx,给 worker 调 OSS 用
|
|
||||||
func setTaskHeadersToCtx(ctx context.Context, headers map[string]string) context.Context {
|
|
||||||
if headers == nil {
|
|
||||||
return ctx
|
|
||||||
}
|
|
||||||
if v := gconv.String(headers["Authorization"]); v != "" {
|
|
||||||
ctx = context.WithValue(ctx, "token", v)
|
|
||||||
}
|
|
||||||
if v := gconv.String(headers["X-User-Info"]); v != "" {
|
|
||||||
ctx = context.WithValue(ctx, "xUserInfo", v)
|
|
||||||
}
|
|
||||||
return ctx
|
|
||||||
}
|
|
||||||
@@ -3,7 +3,7 @@ package service
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"model-gateway/common/util"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"model-gateway/dao"
|
"model-gateway/dao"
|
||||||
@@ -21,13 +21,13 @@ var Task = &taskService{}
|
|||||||
type taskService struct{}
|
type taskService struct{}
|
||||||
|
|
||||||
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
|
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
|
||||||
fmt.Printf("打印请求:%+v", req)
|
|
||||||
startAt := time.Now()
|
startAt := time.Now()
|
||||||
// 固化 token/user 等信息
|
// 固化 token/user 等信息
|
||||||
ctx = asyncCtx(ctx)
|
ctx = util.AsyncCtx(ctx)
|
||||||
|
|
||||||
// 1) 检查模型配置
|
// 1) 检查模型配置
|
||||||
m, err := dao.Model.GetByModelName(ctx, req.ModelName)
|
m, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||||
|
ModelName: req.ModelName,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -51,7 +51,7 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
|
|||||||
// 将调用模型的 payload 与透传头信息一起存入 request_payload,供后台 worker 使用
|
// 将调用模型的 payload 与透传头信息一起存入 request_payload,供后台 worker 使用
|
||||||
storedPayload := map[string]any{
|
storedPayload := map[string]any{
|
||||||
"payload": req.RequestPayload,
|
"payload": req.RequestPayload,
|
||||||
"headers": forwardHeaders(ctx),
|
"headers": util.ForwardHeaders(ctx),
|
||||||
}
|
}
|
||||||
|
|
||||||
t := &entity.AsynchTask{
|
t := &entity.AsynchTask{
|
||||||
@@ -127,7 +127,9 @@ func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string,
|
|||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
tryRun := func() bool {
|
tryRun := func() bool {
|
||||||
t, err := dao.Task.GetByTaskID(ctx, taskID)
|
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
|
||||||
|
TaskID: taskID,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
g.Log().Warningf(ctx, "[task-auto-run][stop] taskId=%s reason=query_failed err=%v", taskID, err)
|
g.Log().Warningf(ctx, "[task-auto-run][stop] taskId=%s reason=query_failed err=%v", taskID, err)
|
||||||
return true
|
return true
|
||||||
@@ -138,7 +140,7 @@ func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string,
|
|||||||
}
|
}
|
||||||
switch t.State {
|
switch t.State {
|
||||||
case 0:
|
case 0:
|
||||||
if err := AsyncWorker.RunByTaskID(ctx, taskID, epicycleId); err != nil {
|
if err = AsyncWorker.RunByTaskID(ctx, taskID, epicycleId); err != nil {
|
||||||
g.Log().Warningf(ctx, "[task-auto-run][retry] taskId=%s state=0 err=%v", taskID, err)
|
g.Log().Warningf(ctx, "[task-auto-run][retry] taskId=%s state=0 err=%v", taskID, err)
|
||||||
} else {
|
} else {
|
||||||
g.Log().Infof(ctx, "[task-auto-run][triggered] taskId=%s state=0", taskID)
|
g.Log().Infof(ctx, "[task-auto-run][triggered] taskId=%s state=0", taskID)
|
||||||
@@ -175,7 +177,9 @@ func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
|
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
|
||||||
t, err := dao.Task.GetByTaskID(ctx, taskID)
|
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
|
||||||
|
TaskID: taskID,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -209,7 +213,9 @@ func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (r
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// 按模型配置决定保留时间
|
// 按模型配置决定保留时间
|
||||||
m, err := dao.Model.GetByModelName(ctx, t.ModelName)
|
m, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||||
|
ModelName: t.ModelName,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,38 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
)
|
|
||||||
|
|
||||||
// saveTmpResult 将模型输出写入临时文件,用于 OSS 上传失败后的“仅重试 OSS”。
|
|
||||||
func saveTmpResult(taskID string, data []byte, ext string) (string, error) {
|
|
||||||
dir := filepath.Join(os.TempDir(), "model-asynch")
|
|
||||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
if ext == "" {
|
|
||||||
ext = ".bin"
|
|
||||||
}
|
|
||||||
if ext[0] != '.' {
|
|
||||||
ext = "." + ext
|
|
||||||
}
|
|
||||||
path := filepath.Join(dir, fmt.Sprintf("%s%s", taskID, ext))
|
|
||||||
if err := os.WriteFile(path, data, 0o644); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return path, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func loadTmpResult(path string) ([]byte, error) {
|
|
||||||
return os.ReadFile(path)
|
|
||||||
}
|
|
||||||
|
|
||||||
func deleteTmpResult(path string) {
|
|
||||||
if path == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_ = os.Remove(path)
|
|
||||||
}
|
|
||||||
|
|
||||||
113
service/utils.go
113
service/utils.go
@@ -1,113 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/gogf/gf/v2/container/gvar"
|
|
||||||
)
|
|
||||||
|
|
||||||
func normalizeFormValue(v any) any {
|
|
||||||
// 目标:对外永远返回 JSON 数组/对象,而不是字符串。
|
|
||||||
if v == nil {
|
|
||||||
return []any{}
|
|
||||||
}
|
|
||||||
switch t := v.(type) {
|
|
||||||
case string:
|
|
||||||
s := strings.TrimSpace(t)
|
|
||||||
if s == "" {
|
|
||||||
return []any{}
|
|
||||||
}
|
|
||||||
return normalizeFormValueFromJSONString(s)
|
|
||||||
case []byte:
|
|
||||||
if len(t) == 0 {
|
|
||||||
return []any{}
|
|
||||||
}
|
|
||||||
return normalizeFormValueFromJSONBytes(t)
|
|
||||||
case *gvar.Var:
|
|
||||||
// goframe 常见的 DB 返回类型
|
|
||||||
if t == nil {
|
|
||||||
return []any{}
|
|
||||||
}
|
|
||||||
b := t.Bytes()
|
|
||||||
if len(b) > 0 {
|
|
||||||
return normalizeFormValueFromJSONBytes(b)
|
|
||||||
}
|
|
||||||
s := strings.TrimSpace(t.String())
|
|
||||||
if s == "" {
|
|
||||||
return []any{}
|
|
||||||
}
|
|
||||||
return normalizeFormValueFromJSONString(s)
|
|
||||||
default:
|
|
||||||
// 尝试兼容其他“像 JSON 的值类型”(例如实现了 Bytes/String 的包装类型)
|
|
||||||
if vb, ok := v.(interface{ Bytes() []byte }); ok {
|
|
||||||
if b := vb.Bytes(); len(b) > 0 {
|
|
||||||
return normalizeFormValueFromJSONBytes(b)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if vs, ok := v.(interface{ String() string }); ok {
|
|
||||||
if s := strings.TrimSpace(vs.String()); s != "" {
|
|
||||||
return normalizeFormValueFromJSONString(s)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// 已经是 []any / map[string]any 等结构
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 兼容“JSONB 里存了 JSON 字符串”的历史数据:
|
|
||||||
// 例如 form_json = '"[]"' 或 '"[{...}]"'(外层是字符串,内层才是数组/对象)
|
|
||||||
func normalizeFormValueFromJSONString(s string) any {
|
|
||||||
var out any
|
|
||||||
if err := json.Unmarshal([]byte(s), &out); err != nil || out == nil {
|
|
||||||
return []any{}
|
|
||||||
}
|
|
||||||
// 如果解出来还是 string,且看起来是 JSON,再解一层
|
|
||||||
if inner, ok := out.(string); ok {
|
|
||||||
inner = strings.TrimSpace(inner)
|
|
||||||
if inner == "" {
|
|
||||||
return []any{}
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(inner, "[") || strings.HasPrefix(inner, "{") {
|
|
||||||
var out2 any
|
|
||||||
if err := json.Unmarshal([]byte(inner), &out2); err == nil && out2 != nil {
|
|
||||||
return out2
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return []any{}
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
func normalizeFormValueFromJSONBytes(b []byte) any {
|
|
||||||
var out any
|
|
||||||
if err := json.Unmarshal(b, &out); err != nil || out == nil {
|
|
||||||
return []any{}
|
|
||||||
}
|
|
||||||
// bytes 解出来也可能是 string(同上)
|
|
||||||
if inner, ok := out.(string); ok {
|
|
||||||
return normalizeFormValueFromJSONString(inner)
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
func ParseJSONField(field any) any {
|
|
||||||
var v *gvar.Var
|
|
||||||
switch val := field.(type) {
|
|
||||||
case *gvar.Var:
|
|
||||||
v = val
|
|
||||||
default:
|
|
||||||
return field
|
|
||||||
}
|
|
||||||
|
|
||||||
if v == nil || v.IsNil() || v.IsEmpty() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
str := v.String()
|
|
||||||
var result any
|
|
||||||
if json.Unmarshal([]byte(str), &result) == nil {
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
return str
|
|
||||||
}
|
|
||||||
@@ -2,7 +2,13 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"model-gateway/common/util"
|
||||||
|
"model-gateway/model/dto"
|
||||||
|
"model-gateway/service/gateway"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
@@ -23,24 +29,23 @@ type asyncWorker struct {
|
|||||||
// RunOnce 由上层定时任务触发:一次性抢占并处理一批任务
|
// RunOnce 由上层定时任务触发:一次性抢占并处理一批任务
|
||||||
// - batchSize: 本次抢占数量
|
// - batchSize: 本次抢占数量
|
||||||
// - goroutines: 本次并发数(协程池大小)
|
// - goroutines: 本次并发数(协程池大小)
|
||||||
func (w *asyncWorker) RunOnce(ctx context.Context, batchSize, goroutines int) (claimed int, err error) {
|
func (w *asyncWorker) RunOnce(ctx context.Context, req *dto.RunWorkReq) (res *dto.RunWorkRes, err error) {
|
||||||
if batchSize <= 0 {
|
if req.BatchSize <= 0 {
|
||||||
batchSize = 10
|
req.BatchSize = 10
|
||||||
}
|
}
|
||||||
if goroutines <= 0 {
|
if req.Goroutines <= 0 {
|
||||||
goroutines = 1
|
req.Goroutines = 1
|
||||||
}
|
}
|
||||||
tasks, err := dao.Task.ClaimPendingGlobal(ctx, batchSize)
|
tasks, err := dao.Task.ClaimPendingGlobal(ctx, req.BatchSize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if len(tasks) == 0 {
|
if len(tasks) == 0 {
|
||||||
return 0, nil
|
return nil, errors.New("no task to run")
|
||||||
}
|
}
|
||||||
pool := grpool.New(goroutines)
|
pool := grpool.New(req.Goroutines)
|
||||||
defer pool.Close()
|
defer pool.Close()
|
||||||
|
claimed := len(tasks)
|
||||||
claimed = len(tasks)
|
|
||||||
done := make(chan struct{}, claimed)
|
done := make(chan struct{}, claimed)
|
||||||
for _, t := range tasks {
|
for _, t := range tasks {
|
||||||
task := t
|
task := t
|
||||||
@@ -58,7 +63,9 @@ func (w *asyncWorker) RunOnce(ctx context.Context, batchSize, goroutines int) (c
|
|||||||
for i := 0; i < claimed; i++ {
|
for i := 0; i < claimed; i++ {
|
||||||
<-done
|
<-done
|
||||||
}
|
}
|
||||||
return claimed, nil
|
return &dto.RunWorkRes{
|
||||||
|
Claimed: claimed,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunByTaskID 创建任务后立即异步尝试执行当前任务:
|
// RunByTaskID 创建任务后立即异步尝试执行当前任务:
|
||||||
@@ -78,9 +85,9 @@ func (w *asyncWorker) RunByTaskID(ctx context.Context, taskID string, epicycleId
|
|||||||
|
|
||||||
func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicycleId int64) {
|
func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicycleId int64) {
|
||||||
// 从任务入库的 request_payload 里恢复 payload + headers
|
// 从任务入库的 request_payload 里恢复 payload + headers
|
||||||
payload, headers := parseStoredPayload(t.RequestPayload)
|
payload, headers := util.ParseStoredPayload(t.RequestPayload)
|
||||||
if len(headers) > 0 {
|
if len(headers) > 0 {
|
||||||
ctx = setTaskHeadersToCtx(ctx, headers)
|
ctx = util.SetTaskHeadersToCtx(ctx, headers)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 1) 拉取模型配置
|
// 1) 拉取模型配置
|
||||||
@@ -91,7 +98,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
|||||||
// ============ 失败回调 ============
|
// ============ 失败回调 ============
|
||||||
t.State = 3
|
t.State = 3
|
||||||
t.ErrorMsg = err.Error()
|
t.ErrorMsg = err.Error()
|
||||||
go triggerCallback(context.WithoutCancel(ctx), t)
|
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||||||
// ================================
|
// ================================
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -102,7 +109,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
|||||||
// ============ 失败回调 ============
|
// ============ 失败回调 ============
|
||||||
t.State = 3
|
t.State = 3
|
||||||
t.ErrorMsg = errMsg
|
t.ErrorMsg = errMsg
|
||||||
go triggerCallback(context.WithoutCancel(ctx), t)
|
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||||||
// ================================
|
// ================================
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -118,7 +125,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
|||||||
// ============ 失败回调 ============
|
// ============ 失败回调 ============
|
||||||
t.State = 3
|
t.State = 3
|
||||||
t.ErrorMsg = err.Error()
|
t.ErrorMsg = err.Error()
|
||||||
go triggerCallback(context.WithoutCancel(ctx), t)
|
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||||||
// ================================
|
// ================================
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -147,9 +154,9 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
|||||||
|
|
||||||
// phase=1 表示模型已成功但 OSS 上传失败:优先从临时文件加载
|
// phase=1 表示模型已成功但 OSS 上传失败:优先从临时文件加载
|
||||||
if t.Phase == 1 && strings.TrimSpace(t.TmpFile) != "" {
|
if t.Phase == 1 && strings.TrimSpace(t.TmpFile) != "" {
|
||||||
data, err = loadTmpResult(t.TmpFile)
|
data, err = os.ReadFile(t.TmpFile)
|
||||||
if err == nil && len(data) > 0 {
|
if err == nil && len(data) > 0 {
|
||||||
contentType, ext = DetectFileType(data)
|
contentType, ext = util.DetectFileType(data)
|
||||||
} else {
|
} else {
|
||||||
data = nil
|
data = nil
|
||||||
}
|
}
|
||||||
@@ -165,11 +172,11 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
|||||||
// ============ 失败回调 ============
|
// ============ 失败回调 ============
|
||||||
t.State = 3
|
t.State = 3
|
||||||
t.ErrorMsg = err.Error()
|
t.ErrorMsg = err.Error()
|
||||||
go triggerCallback(context.WithoutCancel(ctx), t)
|
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||||||
// ================================
|
// ================================
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
contentType, ext = DetectFileType(data)
|
contentType, ext = util.DetectFileType(data)
|
||||||
if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") {
|
if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") {
|
||||||
textResult = string(data)
|
textResult = string(data)
|
||||||
}
|
}
|
||||||
@@ -182,7 +189,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 4) 存储 OSS
|
// 4) 存储 OSS
|
||||||
ossURL, err := Storage.UploadByTask(ctx, t, data, ext, contentType)
|
ossURL, err := gateway.UploadByTask(ctx, t, data, ext, contentType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// OSS 阶段失败:保留临时文件,下一轮仅重试 OSS
|
// OSS 阶段失败:保留临时文件,下一轮仅重试 OSS
|
||||||
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, t.Id, err.Error())
|
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, t.Id, err.Error())
|
||||||
@@ -198,7 +205,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
|||||||
if fileType == "" {
|
if fileType == "" {
|
||||||
fileType = contentType
|
fileType = contentType
|
||||||
}
|
}
|
||||||
if err := dao.Task.UpdateSuccessGlobal(
|
if err = dao.Task.UpdateSuccessGlobal(
|
||||||
ctx,
|
ctx,
|
||||||
t.Id,
|
t.Id,
|
||||||
ossURL,
|
ossURL,
|
||||||
@@ -206,7 +213,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
|||||||
textResult,
|
textResult,
|
||||||
int64(len(data)),
|
int64(len(data)),
|
||||||
nil,
|
nil,
|
||||||
GetExpendTokens(m.TokenMapping, textResult),
|
GetExpendTokens(m.ResponseTokenField, textResult),
|
||||||
); err != nil {
|
); err != nil {
|
||||||
g.Log().Errorf(ctx, "[worker] update success failed: %v", err)
|
g.Log().Errorf(ctx, "[worker] update success failed: %v", err)
|
||||||
return
|
return
|
||||||
@@ -221,14 +228,33 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
|||||||
t.FileType = fileType
|
t.FileType = fileType
|
||||||
t.TextResult = textResult
|
t.TextResult = textResult
|
||||||
g.Log().Infof(ctx, "[CALLBACK][DISPATCH] taskId=%s bizName=%s callbackUrl=%s", t.TaskID, t.BizName, t.CallbackURL)
|
g.Log().Infof(ctx, "[CALLBACK][DISPATCH] taskId=%s bizName=%s callbackUrl=%s", t.TaskID, t.BizName, t.CallbackURL)
|
||||||
go triggerCallback(context.WithoutCancel(ctx), t)
|
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||||||
// ============ 如果有 epicycleId,也触发业务回调 ============
|
// ============ 如果有 epicycleId,也触发业务回调 ============
|
||||||
if epicycleId != 0 {
|
if epicycleId != 0 {
|
||||||
go triggerPromptsCallback(context.WithoutCancel(ctx), t, epicycleId)
|
go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), t, epicycleId)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 成功后清理临时文件
|
// 成功后清理临时文件
|
||||||
deleteTmpResult(t.TmpFile)
|
_ = os.Remove(t.TmpFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
// saveTmpResult 将模型输出写入临时文件,用于 OSS 上传失败后的“仅重试 OSS”。
|
||||||
|
func saveTmpResult(taskID string, data []byte, ext string) (string, error) {
|
||||||
|
dir := filepath.Join(os.TempDir(), "model-asynch")
|
||||||
|
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if ext == "" {
|
||||||
|
ext = ".bin"
|
||||||
|
}
|
||||||
|
if ext[0] != '.' {
|
||||||
|
ext = "." + ext
|
||||||
|
}
|
||||||
|
path := filepath.Join(dir, fmt.Sprintf("%s%s", taskID, ext))
|
||||||
|
if err := os.WriteFile(path, data, 0o644); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return path, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
|
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
|
||||||
@@ -240,7 +266,6 @@ func GetExpendTokens(tokenMapping string, textResult string) int {
|
|||||||
value := gjson.Get(textResult, tokenMapping)
|
value := gjson.Get(textResult, tokenMapping)
|
||||||
if value.Exists() {
|
if value.Exists() {
|
||||||
return int(value.Int())
|
return int(value.Int())
|
||||||
} else {
|
|
||||||
return len(textResult)
|
|
||||||
}
|
}
|
||||||
|
return len(textResult)
|
||||||
}
|
}
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
28
update.sql
28
update.sql
@@ -40,9 +40,18 @@ CREATE TABLE IF NOT EXISTS asynch_models (
|
|||||||
retry_queue_max_seconds INT NOT NULL DEFAULT 600, -- 失败重试最大排队时间(秒 0=插队到队首;>0=排队超过该时间后插队,否则仍到队尾)
|
retry_queue_max_seconds INT NOT NULL DEFAULT 600, -- 失败重试最大排队时间(秒 0=插队到队首;>0=排队超过该时间后插队,否则仍到队尾)
|
||||||
auto_clean_seconds INT NOT NULL DEFAULT 86400, -- 已下载(state=4 后的保留时间(秒),到期清理)
|
auto_clean_seconds INT NOT NULL DEFAULT 86400, -- 已下载(state=4 后的保留时间(秒),到期清理)
|
||||||
remark TEXT DEFAULT '' -- 备注
|
remark TEXT DEFAULT '' -- 备注
|
||||||
token_mapping VARCHAR(128) NOT NULL DEFAULT ''; -- token 映射
|
response_token_field VARCHAR(128) NOT NULL DEFAULT ''; -- 响应中消耗token的字段映射
|
||||||
);
|
operator_name VARCHAR(64) NOT NULL DEFAULT '', -- 运营商名称
|
||||||
|
token_config JSONB NOT NULL DEFAULT '{
|
||||||
|
"zh_ratio": 1.0,
|
||||||
|
"en_ratio": 1.3,
|
||||||
|
"space_ratio": 0.1,
|
||||||
|
"punctuation_ratio": 0.1,
|
||||||
|
"max_window_size": 8192,
|
||||||
|
"reserve_ratio": 0.2,
|
||||||
|
"min_reserve": 512,
|
||||||
|
}'::jsonb -- Token配置
|
||||||
|
);
|
||||||
CREATE UNIQUE INDEX IF NOT EXISTS uk_asynch_models_tenant_creator_chat ON asynch_models(tenant_id, creator) WHERE is_chat_model = 1 AND deleted_at IS NULL;
|
CREATE UNIQUE INDEX IF NOT EXISTS uk_asynch_models_tenant_creator_chat ON asynch_models(tenant_id, creator) WHERE is_chat_model = 1 AND deleted_at IS NULL;
|
||||||
CREATE UNIQUE INDEX IF NOT EXISTS uk_asynch_models_tenant_model_name ON asynch_models(tenant_id, creator, model_name);
|
CREATE UNIQUE INDEX IF NOT EXISTS uk_asynch_models_tenant_model_name ON asynch_models(tenant_id, creator, model_name);
|
||||||
CREATE INDEX IF NOT EXISTS idx_asynch_models_tenant_id ON asynch_models(tenant_id);
|
CREATE INDEX IF NOT EXISTS idx_asynch_models_tenant_id ON asynch_models(tenant_id);
|
||||||
@@ -83,8 +92,17 @@ COMMENT ON COLUMN asynch_models.retry_times IS '失败重试次数';
|
|||||||
COMMENT ON COLUMN asynch_models.retry_queue_max_seconds IS '失败重试最大排队时间(秒 0=插队到队首;>0=排队超过该时间后插队,否则仍到队尾)';
|
COMMENT ON COLUMN asynch_models.retry_queue_max_seconds IS '失败重试最大排队时间(秒 0=插队到队首;>0=排队超过该时间后插队,否则仍到队尾)';
|
||||||
COMMENT ON COLUMN asynch_models.auto_clean_seconds IS '已下载(state=4 后的保留时间(秒),到期清理)';
|
COMMENT ON COLUMN asynch_models.auto_clean_seconds IS '已下载(state=4 后的保留时间(秒),到期清理)';
|
||||||
COMMENT ON COLUMN asynch_models.remark IS '备注';
|
COMMENT ON COLUMN asynch_models.remark IS '备注';
|
||||||
COMMENT ON COLUMN asynch_models.token_mapping IS 'token映射';
|
COMMENT ON COLUMN asynch_models.response_token_field IS '响应中消耗token的字段映射';
|
||||||
|
COMMENT ON COLUMN asynch_models.operator_name IS '运营商名称';
|
||||||
|
COMMENT ON COLUMN asynch_models.token_config IS '{
|
||||||
|
"zh_ratio": 1.0, // 中文字符→token系数
|
||||||
|
"en_ratio": 1.3, // 英文单词→token系数
|
||||||
|
"space_ratio": 0.1, // 空格系数
|
||||||
|
"punctuation_ratio": 0.1, // 标点系数
|
||||||
|
"max_window_size": 8192, // 模型最大窗口
|
||||||
|
"reserve_ratio": 0.2, // 预留回复空间比例
|
||||||
|
"min_reserve": 512, // 最少预留token数
|
||||||
|
}';
|
||||||
|
|
||||||
|
|
||||||
-- =========================
|
-- =========================
|
||||||
|
|||||||
Reference in New Issue
Block a user