refactor(service): 重构服务模块结构并优化模型配置

This commit is contained in:
2026-05-29 17:54:19 +08:00
parent e487b4bb5e
commit d409b84b58
24 changed files with 943 additions and 1158 deletions

10
common/util/convert.go Normal file
View File

@@ -0,0 +1,10 @@
package util
import "github.com/gogf/gf/v2/util/gconv"
// ConvertTo 转换为指定类型
func ConvertTo[T any](v interface{}) *T {
var t T
_ = gconv.Struct(v, &t)
return &t
}

View File

@@ -1,69 +0,0 @@
package util
import (
"encoding/json"
"fmt"
)
// ValidatePromptResult 完整的校验逻辑
func ValidatePromptResult(raw map[string]any, requestMapping map[string]any) error {
contentStr, ok := raw["content"].(string)
if !ok || contentStr == "" {
return fmt.Errorf("content 字段为空或不是字符串")
}
var rounds []map[string]any
if err := json.Unmarshal([]byte(contentStr), &rounds); err != nil {
return fmt.Errorf("解析 content JSON 数组失败: %w", err)
}
if len(rounds) == 0 {
return fmt.Errorf("content 数组为空")
}
// 对 rounds 中的每一个元素进行结构校验
for i, round := range rounds {
if err := validateStructure(requestMapping, round); err != nil {
return fmt.Errorf("rounds[%d] 结构校验失败: %w", i, err)
}
}
return nil
}
// validateStructure 递归校验 actual 是否包含 expected 定义的所有字段路径
func validateStructure(expected any, actual any) error {
switch exp := expected.(type) {
case map[string]any:
act, ok := actual.(map[string]any)
if !ok {
return fmt.Errorf("期望对象,实际类型 %T", actual)
}
for key, expVal := range exp {
actVal, exists := act[key]
if !exists {
return fmt.Errorf("缺少字段: %s", key)
}
if err := validateStructure(expVal, actVal); err != nil {
return fmt.Errorf("%s: %w", key, err)
}
}
return nil
case []any:
act, ok := actual.([]any)
if !ok {
return fmt.Errorf("期望数组,实际类型 %T", actual)
}
if len(exp) == 0 {
return nil // 空数组模板,只校验类型
}
// 用第一个元素的结构去校验每个实际元素
for i, actItem := range act {
if err := validateStructure(exp[0], actItem); err != nil {
return fmt.Errorf("[%d]: %w", i, err)
}
}
return nil
default:
// 基本类型,不校验具体值,只检查存在
return nil
}
}

151
common/util/mapping.go Normal file
View File

@@ -0,0 +1,151 @@
package util
import (
"fmt"
"model-gateway/model/entity"
"net/url"
"strings"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// ValidatePromptResult 校验模型返回结果的 JSON 结构完整性
// 校验逻辑:只校验 requestMapping 中默认值为空的必填字段
func ValidatePromptResult(raw map[string]any, model *entity.AsynchModel) error {
// 1) 获取校验配置,并取值
requestMapping := model.RequestMapping
contentKey := ""
for k := range model.ResponseBody {
contentKey = k
break
}
contentStr, ok := raw[contentKey].(string)
if !ok || contentStr == "" {
return fmt.Errorf("%s 字段为空或不是字符串", contentKey)
}
// 2) 解析 content 为 JSON 数组
var rounds []map[string]any
if err := gjson.DecodeTo(contentStr, &rounds); err != nil {
return fmt.Errorf("解析 content JSON 数组失败: %w", err)
}
if len(rounds) == 0 {
return fmt.Errorf("content 数组为空")
}
// 3) 逐条校验:只检查默认值为空的必填字段是否存在
for i, round := range rounds {
for path, defaultValue := range requestMapping {
if !g.IsEmpty(defaultValue) {
continue
}
if gjson.New(round).Get(path).IsNil() {
return fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, path)
}
}
}
return nil
}
// ReverseMap 映射 payload 到 mapping
func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
jsonObj := gjson.New("{}")
for path, defaultValue := range mapping {
// 从 payload 取对应路径的值
val := gjson.New(payload).Get(path)
if !val.IsNil() {
// payload 有值,用它
_ = jsonObj.Set(path, val.Val())
} else if !g.IsEmpty(defaultValue) {
// payload 没值,用默认值
_ = jsonObj.Set(path, defaultValue)
}
}
return jsonObj.Map()
}
// MapResponsePayload 映射模型响应为标准格式
func MapResponsePayload(mapping map[string]any, responseBytes []byte) ([]byte, error) {
if len(mapping) == 0 {
return responseBytes, nil
}
responseJson := gjson.New(responseBytes)
resultJson := gjson.New("{}")
for standardField, modelPath := range mapping {
path := gconv.String(modelPath)
if path == "" {
continue
}
val := responseJson.Get(path)
if val.IsNil() {
continue
}
resultJson.Set(standardField, val.Val())
}
return []byte(resultJson.String()), nil
}
// ParseHeadMsgHeaders 支持多个 header 绑定,逗号分隔:
// 示例:
// - X-API-Key:qwen3-tts-key,operation:true,count:123
// - X-API-Key:"qwen3-tts-key",operation:"true"
//
// 说明:
// - HTTP Header 最终都是字符串,这里做的是“值的字符串化表达”。
// - 若 value 用双引号包裹,会去掉外层引号再注入,便于在配置中区分字符串/布尔/数字等表达(以及避免值中包含特殊字符时歧义)。
func ParseHeadMsgHeaders(headMsg string) map[string]string {
headMsg = strings.TrimSpace(headMsg)
if headMsg == "" {
return nil
}
out := map[string]string{}
parts := strings.Split(headMsg, ",")
for _, p := range parts {
p = strings.TrimSpace(p)
if p == "" {
continue
}
// HeaderName:HeaderValue推荐 / HeaderName=HeaderValue兼容
if strings.Contains(p, ":") {
kv := strings.SplitN(p, ":", 2)
k := strings.TrimSpace(kv[0])
v := strings.TrimSpace(kv[1])
v = strings.Trim(v, "\"")
if k != "" && v != "" {
out[k] = v
}
continue
}
if strings.Contains(p, "=") {
kv := strings.SplitN(p, "=", 2)
k := strings.TrimSpace(kv[0])
v := strings.TrimSpace(kv[1])
v = strings.Trim(v, "\"")
if k != "" && v != "" {
out[k] = v
}
continue
}
}
if len(out) == 0 {
return nil
}
return out
}
// PayloadToQuery 将 payload 转为 url.Values
func PayloadToQuery(payload map[string]any) (url.Values, error) {
q := url.Values{}
for k, v := range payload {
if v == nil {
continue
}
q.Set(k, gconv.String(v))
}
return q, nil
}

140
common/util/network.go Normal file
View File

@@ -0,0 +1,140 @@
package util
import (
"context"
"net"
"strings"
"github.com/gogf/gf/v2/frame/g"
)
// GetLocalIP 获取本机有效的局域网 IPv4 地址
func GetLocalIP() string {
addrs, err := net.InterfaceAddrs()
if err != nil {
return "127.0.0.1"
}
var validIPs []string
for _, addr := range addrs {
ipnet, ok := addr.(*net.IPNet)
if !ok {
continue
}
ip := ipnet.IP
if isIPValid(ip) {
validIPs = append(validIPs, ip.String())
}
}
// 优先返回非 169.254.x.x 的 IP
for _, ip := range validIPs {
if !strings.HasPrefix(ip, "169.254.") {
return ip
}
}
// 其次返回 169.254.x.x最后的选择
if len(validIPs) > 0 {
return validIPs[0]
}
return "127.0.0.1"
}
// isIPValid 判断 IP 是否有效
func isIPValid(ip net.IP) bool {
// 不是 loopback (127.0.0.1)
if ip.IsLoopback() {
return false
}
// 是 IPv4
if ip.To4() == nil {
return false
}
// 不是链路本地地址 (169.254.0.0/16)
if ip[0] == 169 && ip[1] == 254 {
return false
}
// 不是组播地址
if ip.IsMulticast() {
return false
}
// 不是未指定地址 (0.0.0.0)
if ip.IsUnspecified() {
return false
}
return true
}
// GetLocalAddress 获取局域网地址IP:端口)
func GetLocalAddress(ctx context.Context) string {
ip := GetLocalIP()
port := GetServerPort(ctx)
if port == "80" || port == "443" {
return ip
}
return ip + ":" + port
}
// GetSchemaFromRequest 从当前请求中获取协议http/https
func GetSchemaFromRequest(ctx context.Context) string {
r := g.RequestFromCtx(ctx)
if r == nil {
return "http"
}
// 1. 代理场景X-Forwarded-Proto
if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" {
return proto
}
// 2. 代理场景X-Forwarded-Scheme
if proto := r.Header.Get("X-Forwarded-Scheme"); proto != "" {
return proto
}
// 3. TLS 连接(直接 HTTPS
if r.TLS != nil {
return "https"
}
// 4. 默认 HTTP这行很重要
return "http" // ← 确保有这行
}
// GetLocalBaseURL 获取局域网基础 URL动态协议 + IP + 端口)
func GetLocalBaseURL(ctx context.Context) string {
schema := GetSchemaFromRequest(ctx)
addr := GetLocalAddress(ctx)
return schema + "://" + addr
}
// GetCallbackURL 获取回调地址(完整 URL
func GetCallbackURL(ctx context.Context, path string) string {
baseURL := GetLocalBaseURL(ctx)
// 确保 path 以 / 开头
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
return baseURL + path
}
// GetServerPort 从配置获取服务端口
func GetServerPort(ctx context.Context) string {
address := g.Cfg().MustGet(ctx, "server.address", ":8080").String()
// address 格式如 ":3009",去掉冒号
if strings.HasPrefix(address, ":") {
return address[1:]
}
return "8080"
}

View File

@@ -9,6 +9,7 @@ const (
ImageSubTypeImageToImage = 202 // 图片模型-图生图
ImageSubTypeImageEdit = 203 // 图片模型-图片编辑
ImageSubTypeImageVariation = 204 // 图片模型-图片变体
ImageSubTypeImageTextToImage = 205 // 图片模型-图文生图
ModelTypeAudio = 300 // 音频模型
AudioSubTypeTextToSpeech = 301 // 音频模型-文生音
@@ -40,6 +41,7 @@ var ModelTypeName = map[int]string{
ImageSubTypeImageToImage: "图片模型-图生图",
ImageSubTypeImageEdit: "图片模型-图片编辑",
ImageSubTypeImageVariation: "图片模型-图片变体",
ImageSubTypeImageTextToImage: "图片模型-图文生图",
ModelTypeAudio: "音频模型",
AudioSubTypeTextToSpeech: "音频模型-文生音",

View File

@@ -2,9 +2,9 @@ package controller
import (
"context"
"model-gateway/model/dto"
"model-gateway/service"
modelService "model-gateway/service/model"
"model-gateway/service/queue"
)
type model struct{}
@@ -14,53 +14,53 @@ var Model = new(model)
// CreateModel 添加配置
func (c *model) CreateModel(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
return service.Model.Create(ctx, req)
return modelService.Model.Create(ctx, req)
}
// UpdateModel 更改配置
func (c *model) UpdateModel(ctx context.Context, req *dto.UpdateModelReq) (res *dto.UpdateModelRes, err error) {
err = service.Model.Update(ctx, req)
err = modelService.Model.Update(ctx, req)
return
}
// DeleteModel 删除配置
func (c *model) DeleteModel(ctx context.Context, req *dto.DeleteModelReq) (res *dto.DeleteModelRes, err error) {
err = service.Model.Delete(ctx, req)
err = modelService.Model.Delete(ctx, req)
return
}
// GetModel 获取配置详情
func (c *model) GetModel(ctx context.Context, req *dto.GetModelReq) (res *dto.GetModelRes, err error) {
return service.Model.Get(ctx, req)
return modelService.Model.Get(ctx, req)
}
// ListModel 配置列表
func (c *model) ListModel(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) {
return service.Model.List(ctx, req)
return modelService.Model.List(ctx, req)
}
// AutoTune 动态调参(由上层定时任务每小时触发一次)
func (c *model) AutoTune(ctx context.Context, req *dto.AutoTuneReq) (res *dto.AutoTuneRes, err error) {
return service.AutoTune(ctx, req)
return queue.AutoTune(ctx, req)
}
// ListType 模型类型列表
func (c *model) ListType(ctx context.Context, req *dto.ListTypeReq) (res *dto.TypeItem, err error) {
return service.GetModelTypesFromConfig()
return modelService.GetModelTypesFromConfig()
}
// ListOperator 运营商列表
func (c *model) ListOperator(ctx context.Context, req *dto.ListOperatorReq) (res *dto.ListOperatorRes, err error) {
return service.GetOperatorList()
return modelService.GetOperatorList()
}
// UpdateChatModel 更新是否为聊天模型
func (c *model) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) (res *dto.UpdateChatModelRes, err error) {
err = service.Model.UpdateChatModel(ctx, req)
err = modelService.Model.UpdateChatModel(ctx, req)
return
}
// GetIsChatModel 获取当前会话模型
func (c *model) GetIsChatModel(ctx context.Context, req *dto.GetIsChatModelReq) (res *dto.GetIsChatModelRes, err error) {
return service.Model.GetIsChatModel(ctx)
return modelService.Model.GetIsChatModel(ctx)
}

View File

@@ -2,9 +2,9 @@ package controller
import (
"context"
statService "model-gateway/service/stat"
"model-gateway/model/dto"
"model-gateway/service"
)
type stat struct{}
@@ -14,5 +14,5 @@ var Stat = new(stat)
// ListModelStat 统计列表
func (c *stat) ListModelStat(ctx context.Context, req *dto.ListModelStatReq) (res *dto.ListModelStatRes, err error) {
return service.Stat.List(ctx, req)
return statService.Stat.List(ctx, req)
}

View File

@@ -2,9 +2,10 @@ package controller
import (
"context"
"model-gateway/service/job"
taskService "model-gateway/service/task"
"model-gateway/model/dto"
"model-gateway/service"
)
type task struct{}
@@ -14,30 +15,30 @@ var Task = new(task)
// CreateTask 根据 modelName 创建异步任务,返回 taskId
func (c *task) CreateTask(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
return service.Task.Create(ctx, req)
return taskService.Task.Create(ctx, req)
}
// GetTaskResult 获取任务结果(只返回 oss 地址 + state
func (c *task) GetTaskResult(ctx context.Context, req *dto.GetTaskResultReq) (res *dto.GetTaskResultRes, err error) {
return service.Task.GetResult(ctx, req.TaskID)
return taskService.Task.GetResult(ctx, req.TaskID)
}
// GetTaskBatch 批量查询任务(成功任务标记为已下载)
func (c *task) GetTaskBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
return service.Task.GetBatch(ctx, req)
return taskService.Task.GetBatch(ctx, req)
}
// ListTask 任务列表分页查询
func (c *task) ListTask(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
return service.Task.List(ctx, req)
return taskService.Task.List(ctx, req)
}
// RunWork 手动触发一次 worker由上层定时任务调用
func (c *task) RunWork(ctx context.Context, req *dto.RunWorkReq) (res *dto.RunWorkRes, err error) {
return service.AsyncWorker.RunOnce(ctx, req)
return taskService.AsyncWorker.RunOnce(ctx, req)
}
// CleanWork 手动触发一次 cleaner由上层定时任务调用
func (c *task) CleanWork(ctx context.Context, req *dto.CleanWorkReq) (res *dto.CleanWorkRes, err error) {
return service.Cleaner.RunOnce(ctx)
return job.Cleaner.RunOnce(ctx)
}

View File

@@ -5,6 +5,7 @@ import (
"model-gateway/consts/public"
"model-gateway/model/dto"
"model-gateway/model/entity"
"strconv"
"gitea.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/frame/g"
@@ -90,22 +91,28 @@ func (d *modelDao) GetByAcrossTenant(ctx context.Context, req *entity.AsynchMode
// GetByCreatorAndPlatform 按创建者、平台获取
func (d *modelDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.AsynchModel, total int, err error) {
// 基础 SQL
sql := `
SELECT DISTINCT ON (model_name) *
FROM asynch_models
WHERE deleted_at IS NULL
AND (? = '' OR model_name LIKE ?)
AND (? = 0 OR model_type = ?)
`
args := []any{
req.ModelName, "%" + req.ModelName + "%",
req.ModelType, req.ModelType,
}
// modelType: 传 6 模糊匹配 6%
if req.ModelType > 0 {
prefix := strconv.Itoa(req.ModelType)[:1] // 截取第一位
sql += ` AND model_type::text LIKE ? `
args = append(args, prefix+"%")
}
if !g.IsEmpty(req.IsPrivate) {
sql += ` AND is_private = ? `
args = append(args, req.IsPrivate)
}
if req.IsOwner != nil && *req.IsOwner == 0 {
if req.Enabled != nil && *req.Enabled == 1 {
sql += ` AND creator = ? AND is_owner = ? AND enabled=1 `
@@ -114,9 +121,7 @@ WHERE deleted_at IS NULL
} else {
sql += ` AND creator = ? AND is_owner = ? `
}
args = append(args, req.Creator)
args = append(args, req.IsOwner)
args = append(args, req.Creator, req.IsOwner)
} else if req.IsOwner != nil && *req.IsOwner == 1 {
if req.Enabled != nil && *req.Enabled == 1 {
sql += ` AND ((creator = ? AND is_owner = ? AND enabled=1) OR (is_owner = 0 AND enabled=1)) `
@@ -125,11 +130,9 @@ WHERE deleted_at IS NULL
} else {
sql += ` AND ((creator = ? AND is_owner = ?) OR (is_owner = 0 AND enabled=1)) `
}
args = append(args, req.Creator)
args = append(args, req.IsOwner)
args = append(args, req.Creator, req.IsOwner)
}
// 最后拼接排序
sql += ` ORDER BY model_name, is_owner DESC, created_at DESC`
r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, sql, args...)

View File

@@ -3,13 +3,14 @@ package main
import (
"context"
"model-gateway/model/dto"
"model-gateway/service/job"
"model-gateway/service/task"
"os"
"os/signal"
"syscall"
"time"
"model-gateway/controller"
"model-gateway/service"
"gitea.com/red-future/common/http"
"gitea.com/red-future/common/jaeger"
@@ -62,7 +63,7 @@ func startAutoRunner(ctx context.Context) {
case <-ctx.Done():
return
case <-ticker.C:
if _, err := service.AsyncWorker.RunOnce(ctx, &dto.RunWorkReq{
if _, err := task.AsyncWorker.RunOnce(ctx, &dto.RunWorkReq{
BatchSize: batchSize,
Goroutines: goroutines,
}); err != nil {
@@ -87,7 +88,7 @@ func startAutoRunner(ctx context.Context) {
case <-ctx.Done():
return
case <-ticker.C:
_, _ = service.Cleaner.RunOnce(ctx)
_, _ = job.Cleaner.RunOnce(ctx)
}
}
}()

View File

@@ -24,7 +24,7 @@ type CreateModelReq struct {
ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"`
QueryConfig map[string]any `p:"queryConfig" json:"queryConfig" dc:"结果配置"`
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥,用于模型认证"`
Form map[string]any `p:"form" json:"form" dc:"动态表单配置JSON用于前端渲染配置项"`
Form []map[string]any `p:"form" json:"form" dc:"动态表单配置JSON用于前端渲染配置项"`
RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求映射"`
ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回映射"`
ResponseBody map[string]any `p:"responseBody" json:"responseBody" dc:"返回主体"`
@@ -52,7 +52,7 @@ type UpdateModelReq struct {
HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式GET/POST可选更新"`
HeadMsg string `p:"headMsg" json:"headMsg" dc:"请求头绑定(可选更新)"`
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥,用于模型认证(可选更新)"`
Form map[string]any `p:"form" json:"form" dc:"动态表单配置JSON可选更新"`
Form []map[string]any `p:"form" json:"form" dc:"动态表单配置JSON可选更新"`
RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求参数映射(可选更新)"`
ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回参数映射(可选更新)"`
ResponseBody map[string]any `p:"responseBody" json:"responseBody" dc:"返回主体(可选更新)"`
@@ -166,3 +166,20 @@ type GetIsChatModelReq struct {
type GetIsChatModelRes struct {
Model any `json:"model" dc:"模型详情"`
}
// NodeFormField 节点表单
type NodeFormField struct {
Value any `json:"value" dc:"字段值"`
Field string `json:"field" dc:"字段标识"`
Label string `json:"label" dc:"字段标签"`
Type string `json:"type" dc:"字段类型"`
Required bool `json:"required" dc:"是否必填"`
Default any `json:"default,omitempty" dc:"默认值"`
Options []SelectOption `json:"options" dc:"下拉选项列表"`
FieldConstraint any `json:"fieldConstraint" dc:"字段约束"`
}
type SelectOption struct {
Label string `json:"label" dc:"选项标签"`
Value string `json:"value" dc:"选项值"`
}

View File

@@ -74,7 +74,7 @@ type AsynchModel struct {
BaseURL string `orm:"base_url" json:"baseUrl"`
HttpMethod string `orm:"http_method" json:"httpMethod"`
HeadMsg string `orm:"head_msg" json:"headMsg"`
Form map[string]any `orm:"form_json" json:"form"`
Form []map[string]any `orm:"form_json" json:"form"`
RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"`
ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"`
ResponseBody map[string]any `orm:"response_body" json:"responseBody"`

View File

@@ -1 +0,0 @@
package service

View File

@@ -1,8 +1,9 @@
package service
package job
import (
"context"
"model-gateway/model/dto"
"model-gateway/service/queue"
"os"
"time"
@@ -20,32 +21,32 @@ func (c *cleaner) RunOnce(ctx context.Context) (res *dto.CleanWorkRes, err error
// 1) 清理已下载(state=4)且过期的任务(硬删除 + OSS
expired, err := dao.Task.ListExpiredDownloadedGlobal(ctx, 200)
if err != nil {
g.Log().Errorf(ctx, "[cleaner] list expired(downloaded) error: %v", err)
g.Log().Errorf(ctx, "[清理] 查询已下载过期任务失败: %v", err)
} else {
for _, t := range expired {
_ = os.Remove(t.TmpFile)
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
}
g.Log().Infof(ctx, "[cleaner] expired(downloaded) cleaned, count=%d", len(expired))
g.Log().Infof(ctx, "[清理] 已下载过期任务清理完成, count=%d", len(expired))
}
// 2) 超时任务标失败
list, err := dao.Task.ListTimeoutTasksGlobal(ctx, 200)
if err != nil {
g.Log().Errorf(ctx, "[cleaner] list timeout error: %v", err)
g.Log().Errorf(ctx, "[清理] 查询超时任务失败: %v", err)
} else {
for _, t := range list {
t.ErrorMsg = "任务超时自动失败"
_ = dao.Task.UpdateFailedGlobal(ctx, t)
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
}
g.Log().Infof(ctx, "[cleaner] timeout cleaned, count=%d", len(list))
g.Log().Infof(ctx, "[清理] 超时任务处理完成, count=%d", len(list))
}
// 3) 失败(state=3)的任务按模型配置 retry_times 重新入队(放到队尾)
retryable, err := dao.Task.ListFailedRetryableGlobal(ctx, 200)
if err != nil {
g.Log().Errorf(ctx, "[cleaner] list failed retryable error: %v", err)
g.Log().Errorf(ctx, "[清理] 查询可重试任务失败: %v", err)
} else {
for _, t := range retryable {
// 失败任务重新入队state=3 -> 0先严格占用 queue_limit slot占用失败则留在失败态下一轮再尝试
@@ -54,9 +55,9 @@ func (c *cleaner) RunOnce(ctx context.Context) (res *dto.CleanWorkRes, err error
if err != nil || m == nil {
continue
}
limit := GetRuntimeQueueLimit(ctx, t.ModelName, m.QueueLimit)
limit := queue.GetRuntimeQueueLimit(ctx, t.ModelName, m.QueueLimit)
if limit > 0 {
ok, _ := AcquireQueueSlot(ctx, t.ModelName, t.TaskID, limit, m.ExpectedSeconds)
ok, _ := queue.AcquireQueueSlot(ctx, t.ModelName, t.TaskID, limit, m.ExpectedSeconds)
if !ok {
continue
}
@@ -76,21 +77,21 @@ func (c *cleaner) RunOnce(ctx context.Context) (res *dto.CleanWorkRes, err error
}
_ = dao.Task.RequeueForRetryGlobal(ctx, t.Id, enqueueAt)
}
g.Log().Infof(ctx, "[cleaner] failed retryable cleaned, count=%d", len(retryable))
g.Log().Infof(ctx, "[清理] 可重试任务重新入队完成, count=%d", len(retryable))
}
// 4) 超过重试次数仍失败(state=3)的任务:硬删除
exhausted, err := dao.Task.ListFailedExhaustedGlobal(ctx, 200)
if err != nil {
g.Log().Errorf(ctx, "[cleaner] list failed exhausted error: %v", err)
g.Log().Errorf(ctx, "[清理] 查询重试耗尽任务失败: %v", err)
} else {
for _, t := range exhausted {
_ = os.Remove(t.TmpFile)
// 重试耗尽硬删除:释放闸门占位(兜底,若此前已释放则幂等)
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
}
g.Log().Infof(ctx, "[cleaner] failed exhausted cleaned, count=%d", len(exhausted))
g.Log().Infof(ctx, "[清理] 重试耗尽任务清理完成, count=%d", len(exhausted))
}
return &dto.CleanWorkRes{
Ok: true,

View File

@@ -0,0 +1,254 @@
package model
import (
"context"
"errors"
"model-gateway/common/util"
"model-gateway/consts/public"
"model-gateway/dao"
"model-gateway/model/dto"
"model-gateway/model/entity"
"model-gateway/service/gateway"
"gitea.com/red-future/common/beans"
"gitea.com/red-future/common/db/gfdb"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
var Model = &modelService{}
type modelService struct{}
// Create 创建模型
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (*dto.CreateModelRes, error) {
// 1如果设为会话模型先把该用户旧会话模型取消
if !g.IsEmpty(req.IsChatModel) && *req.IsChatModel == 1 {
if err := s.clearUserChatModel(ctx); err != nil {
return nil, err
}
}
// 2判断是否超管决定 isOwner
req.IsOwner = gconv.PtrInt(1)
if isAdmin, _ := gateway.IsSuperAdmin(ctx); isAdmin {
req.IsOwner = gconv.PtrInt(0)
}
// 3入库
id, err := dao.Model.Insert(ctx, util.ConvertTo[entity.AsynchModel](req))
if err != nil {
return nil, err
}
return &dto.CreateModelRes{ID: id}, nil
}
// Update 更新模型配置
func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) error {
// 1会话模型唯一性校验
if req.IsChatModel != nil && *req.IsChatModel == 1 {
if err := s.checkChatModelUnique(ctx); err != nil {
return err
}
}
// 2超管创建/普通用户更新
req.IsOwner = gconv.PtrInt(1)
if isAdmin, _ := gateway.IsSuperAdmin(ctx); isAdmin {
req.IsOwner = gconv.PtrInt(0)
_, err := dao.Model.Update(ctx, util.ConvertTo[entity.AsynchModel](req))
return err
}
// 3跨租户判断超管的模型不允许直接修改走插入新记录
model, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
})
if err != nil {
return err
}
if model.TenantId == 1 {
_, err = dao.Model.Insert(ctx, util.ConvertTo[entity.AsynchModel](req))
return err
}
_, err = dao.Model.Update(ctx, util.ConvertTo[entity.AsynchModel](req))
return err
}
// Delete 删除模型
func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) error {
_, err := dao.Model.Delete(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
})
return err
}
// Get 获取模型详情
func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetModelRes, error) {
user, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
if g.IsEmpty(req.ID) {
req.Creator = user.UserName
}
modelReq := new(entity.AsynchModel)
err = gconv.Struct(req, modelReq)
if err != nil {
return nil, err
}
model, err := dao.Model.Get(ctx, modelReq)
if err != nil {
return nil, err
}
return &dto.GetModelRes{
Model: model,
}, nil
}
// List 获取模型列表
func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (*dto.ListModelRes, error) {
// 1判断超管
req.IsOwner = gconv.PtrInt(1)
if isAdmin, _ := gateway.IsSuperAdmin(ctx); isAdmin {
req.IsOwner = gconv.PtrInt(0)
}
// 2获取当前用户
user, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
req.Creator = user.UserName
// 3查询
models, total, err := dao.Model.GetByCreatorAndPlatform(ctx, req)
if err != nil {
return nil, err
}
return &dto.ListModelRes{List: models, Total: total}, nil
}
// UpdateChatModel 设置会话模型
func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error {
// 1校验新模型存在
newModel, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
})
if err != nil || newModel == nil {
return errors.New("新会话模型不存在")
}
// 2获取当前用户的会话模型
user, err := utils.GetUserInfo(ctx)
if err != nil {
return err
}
currentModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
IsChatModel: new(1),
})
if err != nil {
return err
}
// 3事务取消旧的 + 设置新的
return gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
if !g.IsEmpty(currentModel) {
if currentModel.ModelType != public.ModelTypeInference {
return errors.New("当前模型为非推理模型,不能设置为会话模型")
}
if currentModel.Id != req.Id {
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id},
IsChatModel: gconv.PtrInt(0),
})
if err != nil {
return err
}
}
}
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
IsChatModel: gconv.PtrInt(1),
})
return err
})
}
// GetIsChatModel 获取当前用户会话模型
func (s *modelService) GetIsChatModel(ctx context.Context) (*dto.GetIsChatModelRes, error) {
user, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
IsChatModel: new(1),
})
if err != nil || model == nil {
return nil, err
}
return &dto.GetIsChatModelRes{Model: model}, nil
}
// ==================== 辅助方法 ====================
// clearUserChatModel 清除当前用户旧会话模型
func (s *modelService) clearUserChatModel(ctx context.Context) error {
user, err := utils.GetUserInfo(ctx)
if err != nil {
return err
}
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
IsChatModel: new(1),
})
if err != nil || model == nil {
return nil
}
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: model.Id},
IsChatModel: gconv.PtrInt(0),
})
return err
}
// checkChatModelUnique 校验用户是否已有会话模型
func (s *modelService) checkChatModelUnique(ctx context.Context) error {
user, err := utils.GetUserInfo(ctx)
if err != nil {
return err
}
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
IsChatModel: new(1),
})
if err != nil {
return err
}
if model != nil {
return errors.New("用户已存在会话模型")
}
return nil
}
// GetModelTypesFromConfig 从配置文件读取模型类型
func GetModelTypesFromConfig() (res *dto.TypeItem, err error) {
// 返回副本,避免外部修改
types := make(map[int]string, len(public.ModelTypeName))
for k, v := range public.ModelTypeName {
types[k] = v
}
return &dto.TypeItem{
Type: types,
}, nil
}
// GetOperatorList 获取运营商列表
func GetOperatorList() (res *dto.ListOperatorRes, err error) {
return &dto.ListOperatorRes{
List: public.OperatorList,
}, nil
}

View File

@@ -1,469 +0,0 @@
package service
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"model-gateway/model/entity"
"net/http"
"net/url"
"strings"
"time"
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/frame/g"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// parseHeadMsgHeaders 支持多个 header 绑定,逗号分隔:
// 示例:
// - X-API-Key:qwen3-tts-key,operation:true,count:123
// - X-API-Key:"qwen3-tts-key",operation:"true"
//
// 说明:
// - HTTP Header 最终都是字符串,这里做的是“值的字符串化表达”。
// - 若 value 用双引号包裹,会去掉外层引号再注入,便于在配置中区分字符串/布尔/数字等表达(以及避免值中包含特殊字符时歧义)。
func parseHeadMsgHeaders(headMsg string) map[string]string {
headMsg = strings.TrimSpace(headMsg)
if headMsg == "" {
return nil
}
out := map[string]string{}
parts := strings.Split(headMsg, ",")
for _, p := range parts {
p = strings.TrimSpace(p)
if p == "" {
continue
}
// HeaderName:HeaderValue推荐 / HeaderName=HeaderValue兼容
if strings.Contains(p, ":") {
kv := strings.SplitN(p, ":", 2)
k := strings.TrimSpace(kv[0])
v := strings.TrimSpace(kv[1])
v = strings.Trim(v, "\"")
if k != "" && v != "" {
out[k] = v
}
continue
}
if strings.Contains(p, "=") {
kv := strings.SplitN(p, "=", 2)
k := strings.TrimSpace(kv[0])
v := strings.TrimSpace(kv[1])
v = strings.Trim(v, "\"")
if k != "" && v != "" {
out[k] = v
}
continue
}
}
if len(out) == 0 {
return nil
}
return out
}
func payloadToQuery(payload any) (url.Values, error) {
if payload == nil {
return url.Values{}, nil
}
// 统一转成 map[string]any
b, err := json.Marshal(payload)
if err != nil {
return nil, err
}
m := map[string]any{}
if err := json.Unmarshal(b, &m); err != nil {
return nil, err
}
q := url.Values{}
for k, v := range m {
if v == nil {
continue
}
// 复杂类型直接 json 字符串化
switch vv := v.(type) {
case string:
q.Set(k, vv)
case float64, bool, int, int64, uint64:
q.Set(k, fmt.Sprintf("%v", vv))
default:
bs, _ := json.Marshal(v)
q.Set(k, string(bs))
}
}
return q, nil
}
// InvokeModel 调用模型服务,返回二进制结果
// modelKey 用于覆盖/补充模型配置 head_msg例如每次请求携带不同的 X-API-Key
func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelKey string) ([]byte, error) {
if m == nil || m.BaseURL == "" {
return nil, fmt.Errorf("模型配置不完整")
}
// ============ 新增:请求参数映射 ============
mappedPayload, err := mapRequestPayload(m.RequestMapping, payload)
if err != nil {
return nil, fmt.Errorf("请求参数映射失败: %w", err)
}
url := strings.TrimRight(m.BaseURL, "/")
timeout := time.Duration(m.TimeoutSeconds) * time.Second
if timeout <= 0 {
timeout = 60 * time.Second
}
client := &http.Client{Timeout: timeout}
method := strings.ToUpper(strings.TrimSpace(m.HttpMethod))
if method == "" {
method = http.MethodPost
}
var (
req *http.Request
)
switch method {
case http.MethodGet:
q, err := payloadToQuery(mappedPayload) // 使用映射后的payload
if err != nil {
return nil, err
}
if len(q) > 0 {
if strings.Contains(url, "?") {
url = url + "&" + q.Encode()
} else {
url = url + "?" + q.Encode()
}
}
req, err = http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
default:
bodyBytes, err := json.Marshal(mappedPayload) // 使用映射后的payload
if err != nil {
return nil, err
}
req, err = http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyBytes))
}
if err != nil {
return nil, err
}
// 先注入模型配置 head_msg静态头部适合公共模型固定 API Key
for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) {
req.Header.Set(hk, hv)
}
// 最后注入动态 modelKey允许覆盖/补充静态 head_msg适合按请求动态传密钥。
for hk, hv := range parseHeadMsgHeaders(modelKey) {
req.Header.Set(hk, hv)
}
if method != http.MethodGet {
req.Header.Set("Content-Type", "application/json")
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
msg := string(b)
if len(msg) > 2000 {
msg = msg[:2000]
}
return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
}
// ============ 新增:响应参数映射 ============
mappedResponse, err := mapResponsePayload(m.ResponseMapping, b)
if err != nil {
// 响应映射失败不阻塞,返回原始数据
g.Log().Warningf(ctx, "响应参数映射失败: %v返回原始数据", err)
return b, nil
}
// =========================================
return mappedResponse, nil
}
//// InvokeModel 调用模型服务,返回二进制结果
//func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelKey string) ([]byte, error) {
// if m == nil || m.BaseURL == "" {
// return nil, fmt.Errorf("模型配置不完整")
// }
// // 请求参数映射
// mappedPayload, err := mapRequestPayload(m.RequestMapping, payload)
// if err != nil {
// return nil, fmt.Errorf("请求参数映射失败: %w", err)
// }
// // 合并请求头
// headers := util.ForwardHeaders(ctx)
// for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) {
// headers[hk] = hv
// }
// for hk, hv := range parseHeadMsgHeaders(modelKey) {
// headers[hk] = hv
// }
//
// // 设置超时
// timeout := time.Duration(m.TimeoutSeconds) * time.Second
// if timeout <= 0 {
// timeout = 600 * time.Second
// }
// ctx, cancel := context.WithTimeout(ctx, timeout)
// defer cancel()
//
// invokeUrl := strings.TrimRight(m.BaseURL, "/")
// method := strings.ToUpper(strings.TrimSpace(m.HttpMethod))
// if method == "" {
// method = http.MethodPost
// }
//
// var respBytes []byte
//
// switch method {
// case http.MethodGet:
// err = commonHttp.Get(ctx, invokeUrl, headers, &respBytes, mappedPayload)
// default:
// err = commonHttp.Post(ctx, invokeUrl, headers, &respBytes, mappedPayload)
// }
// if err != nil {
// return nil, err
// }
// // 响应参数映射
// mappedResponse, err := mapResponsePayload(m.ResponseMapping, respBytes)
// if err != nil {
// g.Log().Warningf(ctx, "响应参数映射失败: %v返回原始数据", err)
// return respBytes, nil
// }
// return mappedResponse, nil
//}
// ============================================
// 映射相关函数
// ============================================
// mapRequestPayload 将标准请求映射为模型特定格式
func mapRequestPayload(mappingAny any, payload any) (any, error) {
// 1. 解析请求映射配置值是any类型支持bool、number等
mapping, err := parseRequestMapping(mappingAny)
if err != nil {
return nil, err
}
// 如果没有映射配置直接返回原始payload
if len(mapping) == 0 {
return payload, nil
}
// 2. 将payload转为map
var payloadMap map[string]any
switch v := payload.(type) {
case map[string]any:
payloadMap = v
case []map[string]any:
// 如果传进来的是纯messages数组包装成标准格式
payloadMap = map[string]any{
"messages": v,
}
default:
// 通过JSON转换
jsonBytes, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("序列化payload失败: %w", err)
}
if err := json.Unmarshal(jsonBytes, &payloadMap); err != nil {
return nil, fmt.Errorf("反序列化payload失败: %w", err)
}
}
// 3. 用数据库固定参数覆盖/补充
for key, value := range mapping {
if existingValue, exists := payloadMap[key]; !exists || isEmptyValue(existingValue) {
payloadMap[key] = value
}
}
return payloadMap, nil
}
// mapResponsePayload 将模型响应映射为标准格式
func mapResponsePayload(mappingAny any, responseBytes []byte) ([]byte, error) {
mapping, err := parseResponseMapping(mappingAny)
if err != nil {
return nil, err
}
if len(mapping) == 0 {
return responseBytes, nil
}
responseStr := string(responseBytes)
resultStr := `{}`
for standardField, modelPath := range mapping {
value := gjson.Get(responseStr, modelPath)
if !value.Exists() {
continue
}
resultStr, err = sjson.SetRaw(resultStr, standardField, value.Raw)
if err != nil {
return nil, fmt.Errorf("提取字段 %s <- %s 失败: %w", standardField, modelPath, err)
}
}
return []byte(resultStr), nil
}
func parseRequestMapping(mappingAny any) (map[string]any, error) {
if mappingAny == nil {
return nil, nil
}
result := make(map[string]any)
switch v := mappingAny.(type) {
case *gvar.Var:
if v == nil || v.IsNil() || v.IsEmpty() {
return nil, nil
}
// 尝试转成 map
if m := v.Map(); m != nil {
for k, val := range m {
result[k] = val
}
return result, nil
}
// 尝试转成 string
if s := v.String(); s != "" && s != "{}" && s != "null" {
if err := json.Unmarshal([]byte(s), &result); err != nil {
return nil, fmt.Errorf("解析请求映射字符串失败: %w", err)
}
return result, nil
}
return nil, nil
// =======================================================
case map[string]interface{}:
result = v
case string:
if v == "" || v == "{}" || v == "null" {
return nil, nil
}
if err := json.Unmarshal([]byte(v), &result); err != nil {
return nil, fmt.Errorf("解析请求映射字符串失败: %w", err)
}
case []byte:
if len(v) == 0 {
return nil, nil
}
if err := json.Unmarshal(v, &result); err != nil {
return nil, fmt.Errorf("解析请求映射字节失败: %w", err)
}
default:
jsonBytes, err := json.Marshal(mappingAny)
if err != nil {
return nil, fmt.Errorf("序列化映射配置失败: %w", err)
}
if err := json.Unmarshal(jsonBytes, &result); err != nil {
return nil, fmt.Errorf("解析映射配置失败: %w", err)
}
}
return result, nil
}
// parseResponseMapping 解析响应映射配置
// 返回值类型为 map[string]string值都是JSON路径字符串
func parseResponseMapping(mappingAny any) (map[string]string, error) {
if mappingAny == nil {
return nil, nil
}
mapping := make(map[string]string)
switch v := mappingAny.(type) {
case *gvar.Var:
if v == nil || v.IsNil() || v.IsEmpty() {
return nil, nil
}
if m := v.Map(); m != nil {
for k, val := range m {
if strVal, ok := val.(string); ok {
mapping[k] = strVal
}
}
return mapping, nil
}
if s := v.String(); s != "" && s != "{}" && s != "null" {
if err := json.Unmarshal([]byte(s), &mapping); err != nil {
return nil, fmt.Errorf("解析响应映射字符串失败: %w", err)
}
return mapping, nil
}
return nil, nil
case string:
if v == "" || v == "{}" || v == "null" {
return nil, nil
}
if err := json.Unmarshal([]byte(v), &mapping); err != nil {
return nil, fmt.Errorf("解析响应映射字符串失败: %w", err)
}
case map[string]interface{}:
// 数据库JSONB直接返回的map
for k, val := range v {
if strVal, ok := val.(string); ok {
mapping[k] = strVal
}
}
case []byte:
if len(v) == 0 {
return nil, nil
}
if err := json.Unmarshal(v, &mapping); err != nil {
return nil, fmt.Errorf("解析响应映射字节失败: %w", err)
}
default:
jsonBytes, err := json.Marshal(mappingAny)
if err != nil {
return nil, fmt.Errorf("序列化响应映射配置失败: %w", err)
}
if err := json.Unmarshal(jsonBytes, &mapping); err != nil {
return nil, fmt.Errorf("解析响应映射配置失败: %w", err)
}
}
return mapping, nil
}
// isEmptyValue 判断值是否为空
func isEmptyValue(v any) bool {
if v == nil {
return true
}
switch val := v.(type) {
case string:
return val == ""
case []any:
return len(val) == 0
case map[string]any:
return len(val) == 0
default:
return false
}
}

View File

@@ -1,389 +0,0 @@
package service
import (
"context"
"errors"
"model-gateway/consts/public"
"model-gateway/dao"
"model-gateway/model/dto"
"model-gateway/model/entity"
"model-gateway/service/gateway"
"gitea.com/red-future/common/beans"
"gitea.com/red-future/common/db/gfdb"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
var Model = &modelService{}
type modelService struct{}
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
// 获取当前会话模型
if !g.IsEmpty(req.IsChatModel) && *req.IsChatModel == 1 {
var user *beans.User
user, err = utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
// 获取当前用户会话模型
var model *entity.AsynchModel
model, err = dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{
Creator: user.UserName,
},
IsChatModel: new(1),
})
if err != nil {
return nil, err
}
// 如果有会话模型,那就改变为 0
if model != nil {
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: model.Id},
IsChatModel: gconv.PtrInt(0),
})
if err != nil {
return nil, err
}
}
}
req.IsOwner = gconv.PtrInt(1)
admin, err := gateway.IsSuperAdmin(ctx)
if err != nil {
return
}
if admin {
req.IsOwner = gconv.PtrInt(0)
}
id, err := dao.Model.Insert(ctx, &entity.AsynchModel{
ModelName: req.ModelName,
ModelType: req.ModelType,
BaseURL: req.BaseURL,
HttpMethod: req.HttpMethod,
HeadMsg: req.HeadMsg,
Form: req.Form,
RequestMapping: req.RequestMapping,
ResponseMapping: req.ResponseMapping,
ResponseBody: req.ResponseBody,
ResponseTokenField: req.ResponseTokenField,
IsPrivate: req.IsPrivate,
IsChatModel: req.IsChatModel,
ApiKey: req.ApiKey,
Enabled: req.Enabled,
MaxConcurrency: req.MaxConcurrency,
QueueLimit: req.QueueLimit,
TimeoutSeconds: req.TimeoutSeconds,
ExpectedSeconds: req.ExpectedSeconds,
RetryTimes: req.RetryTimes,
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
AutoCleanSeconds: req.AutoCleanSeconds,
Remark: req.Remark,
IsOwner: req.IsOwner,
OperatorName: req.OperatorName,
TokenConfig: req.TokenConfig,
ExtendMapping: req.ExtendMapping,
QueryConfig: req.QueryConfig,
})
if err != nil {
return nil, err
}
return &dto.CreateModelRes{ID: id}, nil
}
func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) error {
//根据当前 isChatModel 来判断是否更新模型
if req.IsChatModel == gconv.PtrInt(1) {
user, err := utils.GetUserInfo(ctx)
if err != nil {
return err
}
// 获取当前用户会话模型
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{
Creator: user.UserName,
},
IsChatModel: new(1),
})
if err != nil {
return err
}
if model != nil {
return errors.New("用户已存在会话模型,不能创建")
}
}
req.IsOwner = gconv.PtrInt(1)
admin, err := gateway.IsSuperAdmin(ctx)
if err != nil {
return err
}
if admin {
req.IsOwner = gconv.PtrInt(0)
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
ModelName: req.ModelName,
ModelType: req.ModelType,
BaseURL: req.BaseURL,
HttpMethod: req.HttpMethod,
HeadMsg: req.HeadMsg,
Form: req.Form,
RequestMapping: req.RequestMapping,
ResponseMapping: req.ResponseMapping,
ResponseBody: req.ResponseBody,
ResponseTokenField: req.ResponseTokenField,
IsPrivate: req.IsPrivate,
IsChatModel: req.IsChatModel,
ApiKey: req.ApiKey,
Enabled: req.Enabled,
MaxConcurrency: req.MaxConcurrency,
QueueLimit: req.QueueLimit,
TimeoutSeconds: req.TimeoutSeconds,
ExpectedSeconds: req.ExpectedSeconds,
RetryTimes: req.RetryTimes,
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
AutoCleanSeconds: req.AutoCleanSeconds,
Remark: req.Remark,
IsOwner: req.IsOwner,
OperatorName: req.OperatorName,
TokenConfig: req.TokenConfig,
ExtendMapping: req.ExtendMapping,
QueryConfig: req.QueryConfig,
})
if err != nil {
return err
}
return nil
}
// 判断当前传过来的模型id的模型是否是超级管理员的。如果是超管的进行创建否则更新
model, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
})
if err != nil {
return err
}
if model.TenantId == 1 {
insertDto := new(dto.CreateModelReq)
err = gconv.Struct(req, insertDto)
if err != nil {
return err
}
_, err = dao.Model.Insert(ctx, &entity.AsynchModel{
ModelName: req.ModelName,
ModelType: req.ModelType,
BaseURL: req.BaseURL,
HttpMethod: req.HttpMethod,
HeadMsg: req.HeadMsg,
Form: req.Form,
RequestMapping: req.RequestMapping,
ResponseMapping: req.ResponseMapping,
ResponseBody: req.ResponseBody,
ResponseTokenField: req.ResponseTokenField,
IsPrivate: req.IsPrivate,
IsChatModel: req.IsChatModel,
ApiKey: req.ApiKey,
Enabled: req.Enabled,
MaxConcurrency: req.MaxConcurrency,
QueueLimit: req.QueueLimit,
TimeoutSeconds: req.TimeoutSeconds,
ExpectedSeconds: req.ExpectedSeconds,
RetryTimes: req.RetryTimes,
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
AutoCleanSeconds: req.AutoCleanSeconds,
Remark: req.Remark,
IsOwner: req.IsOwner,
OperatorName: req.OperatorName,
TokenConfig: req.TokenConfig,
ExtendMapping: req.ExtendMapping,
QueryConfig: req.QueryConfig,
})
return err
}
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
ModelName: req.ModelName,
ModelType: req.ModelType,
BaseURL: req.BaseURL,
HttpMethod: req.HttpMethod,
HeadMsg: req.HeadMsg,
Form: req.Form,
RequestMapping: req.RequestMapping,
ResponseMapping: req.ResponseMapping,
ResponseBody: req.ResponseBody,
ResponseTokenField: req.ResponseTokenField,
IsPrivate: req.IsPrivate,
IsChatModel: req.IsChatModel,
ApiKey: req.ApiKey,
Enabled: req.Enabled,
MaxConcurrency: req.MaxConcurrency,
QueueLimit: req.QueueLimit,
TimeoutSeconds: req.TimeoutSeconds,
ExpectedSeconds: req.ExpectedSeconds,
RetryTimes: req.RetryTimes,
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
AutoCleanSeconds: req.AutoCleanSeconds,
Remark: req.Remark,
IsOwner: req.IsOwner,
OperatorName: req.OperatorName,
TokenConfig: req.TokenConfig,
ExtendMapping: req.ExtendMapping,
QueryConfig: req.QueryConfig,
})
return err
}
func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) error {
_, err := dao.Model.Delete(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
})
return err
}
func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetModelRes, error) {
user, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{
Id: req.ID,
Creator: user.UserName,
},
ModelName: req.ModelName,
})
if err != nil {
return nil, err
}
return &dto.GetModelRes{
Model: model,
}, nil
}
func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) {
var models []*entity.AsynchModel
req.IsOwner = gconv.PtrInt(1)
admin, err := gateway.IsSuperAdmin(ctx)
if err != nil {
return
}
if admin {
req.IsOwner = gconv.PtrInt(0)
}
var user *beans.User
user, err = utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
req.Creator = user.UserName
models, total, err := dao.Model.GetByCreatorAndPlatform(ctx, req)
if err != nil {
return
}
return &dto.ListModelRes{
List: models,
Total: total,
}, nil
}
// GetModelTypesFromConfig 从配置文件读取模型类型
func GetModelTypesFromConfig() (res *dto.TypeItem, err error) {
// 返回副本,避免外部修改
types := make(map[int]string, len(public.ModelTypeName))
for k, v := range public.ModelTypeName {
types[k] = v
}
return &dto.TypeItem{
Type: types,
}, nil
}
// GetOperatorList 获取运营商列表
func GetOperatorList() (res *dto.ListOperatorRes, err error) {
return &dto.ListOperatorRes{
List: public.OperatorList,
}, nil
}
func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error {
// 校验新会话模型是否存在
newModel, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
})
if err != nil {
return err
}
if newModel == nil {
return errors.New("新会话模型不存在")
}
var user *beans.User
user, err = utils.GetUserInfo(ctx)
if err != nil {
return err
}
// 获取当前用户会话模型
currentModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{
Creator: user.UserName,
},
IsChatModel: new(1),
})
if err != nil {
return err
}
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
if !g.IsEmpty(currentModel) {
if currentModel.ModelType != public.ModelTypeInference {
return errors.New("当前模型为非推理模型,不能设置为会话模型")
}
// 如果点击的就是当前会话模型已经是1取消它设为0
if currentModel.Id != req.Id {
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id},
IsChatModel: gconv.PtrInt(0),
})
if err != nil {
return err
}
}
}
// 设置当前为会话模型设为1
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
IsChatModel: gconv.PtrInt(1),
})
return err
})
return err
}
func (s *modelService) GetIsChatModel(ctx context.Context) (*dto.GetIsChatModelRes, error) {
user, err := utils.GetUserInfo(ctx)
if err != nil {
return nil, err
}
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{
Creator: user.UserName,
},
IsChatModel: new(1),
})
if err != nil {
return nil, err
}
if model == nil {
return nil, nil
}
return &dto.GetIsChatModelRes{
Model: model,
}, nil
}

View File

@@ -1,4 +1,4 @@
package service
package queue
import (
"context"

View File

@@ -1,4 +1,4 @@
package service
package queue
import (
"context"

View File

@@ -1,4 +1,4 @@
package service
package queue
import (
"context"
@@ -80,4 +80,3 @@ func clampInt(v, minV, maxV int) int {
}
return v
}

View File

@@ -1,4 +1,4 @@
package service
package queue
import (
"context"
@@ -34,7 +34,8 @@ end
return 1
`
func acquireSemaphore(ctx context.Context, key string, max int, ttlSeconds int64) (bool, error) {
// AcquireSemaphore 获取并发令牌
func AcquireSemaphore(ctx context.Context, key string, max int, ttlSeconds int64) (bool, error) {
if max <= 0 {
// 不限制
return true, nil
@@ -49,8 +50,8 @@ func acquireSemaphore(ctx context.Context, key string, max int, ttlSeconds int64
return gconv.Int(r) == 1, nil
}
func releaseSemaphore(ctx context.Context, key string) error {
// ReleaseSemaphore 释放并发令牌
func ReleaseSemaphore(ctx context.Context, key string) error {
_, err := g.Redis().Do(ctx, "EVAL", releaseLua, 1, key)
return err
}

View File

@@ -1,4 +1,4 @@
package service
package stat
import (
"context"

View File

@@ -1,9 +1,10 @@
package service
package task
import (
"context"
"errors"
"model-gateway/common/util"
"model-gateway/service/queue"
"time"
"model-gateway/dao"
@@ -20,10 +21,11 @@ var Task = &taskService{}
type taskService struct{}
// Create 创建任务
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
startAt := time.Now()
// 固化 token/user 等信息
ctx = util.AsyncCtx(ctx)
taskID := uuid.NewString()
// 1) 检查模型配置
m, err := dao.Model.Get(ctx, &entity.AsynchModel{
ModelName: req.ModelName,
@@ -35,11 +37,10 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
return nil, errors.New("模型不存在或未启用")
}
taskID := uuid.NewString()
// 2) 排队上限严格控制Redis 原子闸门)
limit := GetRuntimeQueueLimit(ctx, req.ModelName, m.QueueLimit)
limit := queue.GetRuntimeQueueLimit(ctx, req.ModelName, m.QueueLimit)
if limit > 0 {
ok, err := AcquireQueueSlot(ctx, req.ModelName, taskID, limit, m.ExpectedSeconds)
ok, err := queue.AcquireQueueSlot(ctx, req.ModelName, taskID, limit, m.ExpectedSeconds)
if err != nil {
return nil, err
}
@@ -48,13 +49,12 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
}
}
// 将调用模型的 payload 与透传头信息一起存入 request_payload供后台 worker 使用
// 3) 插入任务记录
storedPayload := map[string]any{
"payload": req.RequestPayload,
"headers": util.ForwardHeaders(ctx),
}
t := &entity.AsynchTask{
_, err = dao.Task.Insert(ctx, &entity.AsynchTask{
ModelName: req.ModelName,
TaskID: taskID,
State: 0,
@@ -64,21 +64,20 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
InputRef: req.InputRef,
RequestPayload: storedPayload,
EpicycleId: req.EpicycleId,
}
_, err = dao.Task.Insert(ctx, t)
})
if err != nil {
// 入库失败:回滚闸门占位
ReleaseQueueSlot(ctx, req.ModelName, taskID)
queue.ReleaseQueueSlot(ctx, req.ModelName, taskID)
return nil, err
}
// 3) 写操作日志(尽量不影响主流程,失败忽略)
// 4) 写操作日志(不影响主流程,失败忽略)
ip := ""
ua := ""
apiPath := "/task/createTask"
httpMethod := "POST"
if r := g.RequestFromCtx(ctx); r != nil {
ip = r.GetClientIp()
ip = util.GetLocalIP()
ua = r.UserAgent()
apiPath = r.URL.Path
httpMethod = r.Method
@@ -101,70 +100,68 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
},
})
// 4) 创建成功后立即异步尝试执行当前任务,并仅在任务仍处于 pending(state=0) 时做定向轮询。
// 5) 创建成功后立即异步尝试执行当前任务,并仅在任务仍处于 pending(state=0) 时做定向轮询。
// 一旦任务进入 running/success/failed/downloaded就停止轮询避免一直空转。
go s.pollAndRunUntilPicked(context.WithoutCancel(ctx), taskID, req)
go s.pollAndRunUntilPicked(util.AsyncCtx(ctx), taskID, req)
return &dto.CreateTaskRes{TaskID: taskID}, nil
}
// pollAndRunUntilPicked 用于 createTask 创建后的“轻量级定向轮询”:
// pollAndRunUntilPicked 定向轮询执行刚创建的任务
// - 目标:尽快把刚创建的任务拉起来执行
// - 只在任务仍为 pending(state=0) 时继续尝试抢占
// - 一旦任务进入 running(1) / success(2) / failed(3) / downloaded(4),立即停止
// - 这样不会无限轮询runWork 仍负责处理积压队列和未处理到的任务
// - 不会无限轮询runWork 仍负责处理积压队列和未处理到的任务
func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string, req *dto.CreateTaskReq) {
if taskID == "" {
return
}
interval := g.Cfg().MustGet(ctx, "asynch.worker.intervalSeconds").Int()
if interval <= 0 {
interval = 5
}
g.Log().Infof(ctx, "[task-auto-run][start] taskId=%s interval=%ds", taskID, interval)
interval := g.Cfg().MustGet(ctx, "asynch.worker.intervalSeconds", 5).Int()
pollTimeout := g.Cfg().MustGet(ctx, "asynch.worker.pollTimeoutSeconds", 300).Int()
pollCtx, cancel := context.WithTimeout(ctx, time.Duration(pollTimeout)*time.Second)
defer cancel()
ticker := time.NewTicker(time.Duration(interval) * time.Second)
defer ticker.Stop()
g.Log().Infof(ctx, "[任务自动执行][开始] taskId=%s 轮询间隔=%ds 超时=%ds", taskID, interval, pollTimeout)
tryRun := func() bool {
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
TaskID: taskID,
})
if err != nil {
g.Log().Warningf(ctx, "[task-auto-run][stop] taskId=%s reason=query_failed err=%v", taskID, err)
g.Log().Warningf(ctx, "[任务自动执行][停止] taskId=%s 原因=查询失败 err=%v", taskID, err)
return true
}
if t == nil {
g.Log().Warningf(ctx, "[task-auto-run][stop] taskId=%s reason=task_not_found", taskID)
g.Log().Warningf(ctx, "[任务自动执行][停止] taskId=%s 原因=任务不存在", taskID)
return true
}
switch t.State {
case 0:
//RunByTaskID 尝试执行任务
if err = AsyncWorker.RunByTaskID(ctx, taskID, req); err != nil {
g.Log().Warningf(ctx, "[task-auto-run][retry] taskId=%s state=0 err=%v", taskID, err)
g.Log().Warningf(ctx, "[任务自动执行][重试] taskId=%s 状态=待处理 err=%v", taskID, err)
} else {
g.Log().Infof(ctx, "[task-auto-run][triggered] taskId=%s state=0", taskID)
g.Log().Infof(ctx, "[任务自动执行][已触发] taskId=%s 状态=待处理", taskID)
}
return false
case 1:
g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=running", taskID)
g.Log().Infof(ctx, "[任务自动执行][停止] taskId=%s 原因=执行中", taskID)
return true
case 2, 3, 4:
g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=terminal state=%d", taskID, t.State)
g.Log().Infof(ctx, "[任务自动执行][停止] taskId=%s 原因=终态 状态=%d", taskID, t.State)
return true
default:
g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=unknown_state state=%d", taskID, t.State)
g.Log().Infof(ctx, "[任务自动执行][停止] taskId=%s 原因=未知状态 状态=%d", taskID, t.State)
return true
}
}
// 先立即尝试一次
// 立即尝试一次
if stop := tryRun(); stop {
return
}
for {
select {
case <-ctx.Done():
g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=context_done", taskID)
case <-pollCtx.Done():
g.Log().Infof(ctx, "[任务自动执行][停止] taskId=%s 原因=轮询超时", taskID)
return
case <-ticker.C:
if stop := tryRun(); stop {
@@ -174,6 +171,7 @@ func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string,
}
}
// GetResult 获取任务结果
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
TaskID: taskID,
@@ -244,6 +242,7 @@ func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (r
return &dto.GetTaskBatchRes{List: items}, nil
}
// List 获取任务列表
func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
pageNum, pageSize := 1, 10
if req != nil {

View File

@@ -1,12 +1,17 @@
package service
package task
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"model-gateway/common/util"
"model-gateway/model/dto"
"model-gateway/service/gateway"
"model-gateway/service/queue"
"net/http"
"os"
"path/filepath"
"strings"
@@ -56,7 +61,7 @@ func (w *asyncWorker) RunOnce(ctx context.Context, req *dto.RunWorkReq) (res *dt
if e != nil {
task.ErrorMsg = fmt.Sprintf("worker panic: %v", e)
_ = dao.Task.UpdateFailedGlobal(ctx, task)
ReleaseQueueSlot(ctx, task.ModelName, task.TaskID)
queue.ReleaseQueueSlot(ctx, task.ModelName, task.TaskID)
}
done <- struct{}{}
})
@@ -100,8 +105,8 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *
// 2) 分布式并发控制
semKey := fmt.Sprintf("asynch:sem:%s", t.ModelName)
maxC := GetRuntimeMaxConcurrency(ctx, t.ModelName, model.MaxConcurrency)
acquired, err := acquireSemaphore(ctx, semKey, maxC, 3600)
maxC := queue.GetRuntimeMaxConcurrency(ctx, t.ModelName, model.MaxConcurrency)
acquired, err := queue.AcquireSemaphore(ctx, semKey, maxC, 3600)
if err != nil {
w.failTask(ctx, t, err.Error())
return
@@ -111,7 +116,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *
_ = w.rollbackToPending(ctx, t.Id)
return
}
defer func() { _ = releaseSemaphore(ctx, semKey) }()
defer func() { _ = queue.ReleaseSemaphore(ctx, semKey) }()
// 3) request_payload 校验
if payload == nil {
@@ -146,31 +151,32 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *
}
// 6) 解析校验(可重试,失败重新调模型)
if req.BuildType == 1 {
for attempt := 0; attempt <= maxRetry; attempt++ {
if attempt > 0 {
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID)
}
err = util.ValidatePromptResult(textResult, model.RequestMapping)
if err == nil {
break
}
g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v",
t.TaskID, attempt, maxRetry, err)
if attempt == maxRetry {
w.failTask(ctx, t, fmt.Sprintf("JSON解析重试耗尽: %v", err))
return
}
// 重新调模型
newResult, modelErr := w.callModel(ctx, t, model, payload)
if modelErr != nil {
g.Log().Warningf(ctx, "[执行任务][重试] 重新调模型失败 taskId=%s attempt=%d/%d err=%v",
t.TaskID, attempt, maxRetry, modelErr)
continue
}
textResult = newResult
}
}
//if req.BuildType == 1 {
// for attempt := 0; attempt <= maxRetry; attempt++ {
// if attempt > 0 {
// g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID)
// }
// // 6.1) 校验数据
// err = util.ValidatePromptResult(textResult, model)
// if err == nil {
// break
// }
// g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v",
// t.TaskID, attempt, maxRetry, err)
// if attempt == maxRetry {
// w.failTask(ctx, t, fmt.Sprintf("JSON解析重试耗尽: %v", err))
// return
// }
// // 6.2) 重新调模型
// newResult, modelErr := w.callModel(ctx, t, model, payload)
// if modelErr != nil {
// g.Log().Warningf(ctx, "[执行任务][重试] 重新调模型失败 taskId=%s attempt=%d/%d err=%v",
// t.TaskID, attempt, maxRetry, modelErr)
// continue
// }
// textResult = newResult
// }
//}
// 7) 成功回调
t.State = 2
@@ -185,7 +191,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *
return
}
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
if req.EpicycleId != 0 {
go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), t, req.EpicycleId)
@@ -198,29 +204,29 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *
// 返回: ossURL(成功时有值), fileType, textResult(失败时是错误信息), retryable(是否可重试)
// callModel 调用模型 + 检测文件类型 + 保存临时文件
func (w *asyncWorker) callModel(ctx context.Context, t *entity.AsynchTask, m *entity.AsynchModel, payload map[string]any) (map[string]any, error) {
func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, payload map[string]any) (map[string]any, error) {
var data []byte
var contentType, ext, textResult string
var err error
if t.Phase == 1 && strings.TrimSpace(t.TmpFile) != "" {
data, err = os.ReadFile(t.TmpFile)
if task.Phase == 1 && strings.TrimSpace(task.TmpFile) != "" {
data, err = os.ReadFile(task.TmpFile)
if err != nil || len(data) == 0 {
data = nil
}
}
if data == nil {
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(t.TenantId), t.Creator, t.ModelName)
data, err = InvokeModel(ctx, m, payload, t.ModelKey)
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName)
data, err = InvokeModel(ctx, model, payload, task.ModelKey)
if err != nil {
return nil, err
}
tmpPath, tmpErr := saveTmpResult(t.TaskID, data, ext)
tmpPath, tmpErr := saveTmpResult(task.TaskID, data, ext)
if tmpErr == nil && tmpPath != "" {
t.TmpFile = tmpPath
t.Phase = 1
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, t.Id, tmpPath)
task.TmpFile = tmpPath
task.Phase = 1
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath)
}
}
@@ -228,10 +234,138 @@ func (w *asyncWorker) callModel(ctx context.Context, t *entity.AsynchTask, m *en
if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") {
textResult = string(data)
}
return gjson.New(textResult).Map(), nil
}
// InvokeModel 调用模型服务,返回二进制结果
// modelKey 用于覆盖/补充模型配置 head_msg例如每次请求携带不同的 X-API-Key
func InvokeModel(ctx context.Context, model *entity.AsynchModel, payload map[string]any, modelKey string) ([]byte, error) {
// 1请求参数映射将标准 payload 按模型配置的 requestMapping 转为模型需要的格式
//mappedPayload := util.ReverseMap(model.RequestMapping, payload)
// 2构建请求 URL 和超时
baseURL := strings.TrimRight(model.BaseURL, "/")
timeout := time.Duration(model.TimeoutSeconds) * time.Second
client := &http.Client{Timeout: timeout}
method := strings.ToUpper(strings.TrimSpace(model.HttpMethod))
// 3构建 HTTP 请求
var req *http.Request
switch method {
case http.MethodGet:
q, err := util.PayloadToQuery(payload)
if err != nil {
return nil, err
}
if len(q) > 0 {
if strings.Contains(baseURL, "?") {
baseURL = baseURL + "&" + q.Encode()
} else {
baseURL = baseURL + "?" + q.Encode()
}
}
req, err = http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil)
default:
bodyBytes, err := json.Marshal(payload)
if err != nil {
return nil, err
}
req, err = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes))
}
// 4注入请求头先模型静态配置再动态 modelKey后者可覆盖前者
for hk, hv := range util.ParseHeadMsgHeaders(model.HeadMsg) {
req.Header.Set(hk, hv)
}
for hk, hv := range util.ParseHeadMsgHeaders(modelKey) {
req.Header.Set(hk, hv)
}
if method != http.MethodGet {
req.Header.Set("Content-Type", "application/json")
}
// 5发送请求
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
// 6读取响应体
b, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
// 7检查 HTTP 状态码
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
msg := string(b)
return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
}
// 8响应参数映射
mappedResponse, err := util.MapResponsePayload(model.ResponseMapping, b)
if err != nil {
g.Log().Warningf(ctx, "响应参数映射失败: %v返回原始数据", err)
return b, nil
}
return mappedResponse, nil
}
// // InvokeModel 调用模型服务,返回二进制结果
//
// func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelKey string) ([]byte, error) {
// if m == nil || m.BaseURL == "" {
// return nil, fmt.Errorf("模型配置不完整")
// }
// // 请求参数映射
// mappedPayload, err := mapRequestPayload(m.RequestMapping, payload)
// if err != nil {
// return nil, fmt.Errorf("请求参数映射失败: %w", err)
// }
// // 合并请求头
// headers := util.ForwardHeaders(ctx)
// for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) {
// headers[hk] = hv
// }
// for hk, hv := range parseHeadMsgHeaders(modelKey) {
// headers[hk] = hv
// }
//
// // 设置超时
// timeout := time.Duration(m.TimeoutSeconds) * time.Second
// if timeout <= 0 {
// timeout = 600 * time.Second
// }
// ctx, cancel := context.WithTimeout(ctx, timeout)
// defer cancel()
//
// invokeUrl := strings.TrimRight(m.BaseURL, "/")
// method := strings.ToUpper(strings.TrimSpace(m.HttpMethod))
// if method == "" {
// method = http.MethodPost
// }
//
// var respBytes []byte
//
// switch method {
// case http.MethodGet:
// err = commonHttp.Get(ctx, invokeUrl, headers, &respBytes, mappedPayload)
// default:
// err = commonHttp.Post(ctx, invokeUrl, headers, &respBytes, mappedPayload)
// }
// if err != nil {
// return nil, err
// }
// // 响应参数映射
// mappedResponse, err := mapResponsePayload(m.ResponseMapping, respBytes)
// if err != nil {
// g.Log().Warningf(ctx, "响应参数映射失败: %v返回原始数据", err)
// return respBytes, nil
// }
// return mappedResponse, nil
// }
// uploadOSS 从临时文件上传 OSS
func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.AsynchTask) (*gateway.UploadFileResponse, error) {
data, err := os.ReadFile(t.TmpFile)
@@ -247,7 +381,7 @@ func (w *asyncWorker) failTask(ctx context.Context, t *entity.AsynchTask, errMsg
t.State = 3
t.ErrorMsg = errMsg
_ = dao.Task.UpdateFailedGlobal(ctx, t)
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
}