diff --git a/common/util/convert.go b/common/util/convert.go new file mode 100644 index 0000000..4148bdc --- /dev/null +++ b/common/util/convert.go @@ -0,0 +1,10 @@ +package util + +import "github.com/gogf/gf/v2/util/gconv" + +// ConvertTo 转换为指定类型 +func ConvertTo[T any](v interface{}) *T { + var t T + _ = gconv.Struct(v, &t) + return &t +} diff --git a/common/util/json.go b/common/util/json.go deleted file mode 100644 index c841e73..0000000 --- a/common/util/json.go +++ /dev/null @@ -1,69 +0,0 @@ -package util - -import ( - "encoding/json" - "fmt" -) - -// ValidatePromptResult 完整的校验逻辑 -func ValidatePromptResult(raw map[string]any, requestMapping map[string]any) error { - contentStr, ok := raw["content"].(string) - if !ok || contentStr == "" { - return fmt.Errorf("content 字段为空或不是字符串") - } - - var rounds []map[string]any - if err := json.Unmarshal([]byte(contentStr), &rounds); err != nil { - return fmt.Errorf("解析 content JSON 数组失败: %w", err) - } - if len(rounds) == 0 { - return fmt.Errorf("content 数组为空") - } - - // 对 rounds 中的每一个元素进行结构校验 - for i, round := range rounds { - if err := validateStructure(requestMapping, round); err != nil { - return fmt.Errorf("rounds[%d] 结构校验失败: %w", i, err) - } - } - return nil -} - -// validateStructure 递归校验 actual 是否包含 expected 定义的所有字段路径 -func validateStructure(expected any, actual any) error { - switch exp := expected.(type) { - case map[string]any: - act, ok := actual.(map[string]any) - if !ok { - return fmt.Errorf("期望对象,实际类型 %T", actual) - } - for key, expVal := range exp { - actVal, exists := act[key] - if !exists { - return fmt.Errorf("缺少字段: %s", key) - } - if err := validateStructure(expVal, actVal); err != nil { - return fmt.Errorf("%s: %w", key, err) - } - } - return nil - case []any: - act, ok := actual.([]any) - if !ok { - return fmt.Errorf("期望数组,实际类型 %T", actual) - } - if len(exp) == 0 { - return nil // 空数组模板,只校验类型 - } - // 用第一个元素的结构去校验每个实际元素 - for i, actItem := range act { - if err := validateStructure(exp[0], actItem); err != nil { - return fmt.Errorf("[%d]: %w", i, err) - } - } - return nil - default: - // 基本类型,不校验具体值,只检查存在 - return nil - } -} diff --git a/common/util/mapping.go b/common/util/mapping.go new file mode 100644 index 0000000..01b0cbe --- /dev/null +++ b/common/util/mapping.go @@ -0,0 +1,151 @@ +package util + +import ( + "fmt" + "model-gateway/model/entity" + "net/url" + "strings" + + "github.com/gogf/gf/v2/encoding/gjson" + "github.com/gogf/gf/v2/frame/g" + "github.com/gogf/gf/v2/util/gconv" +) + +// ValidatePromptResult 校验模型返回结果的 JSON 结构完整性 +// 校验逻辑:只校验 requestMapping 中默认值为空的必填字段 +func ValidatePromptResult(raw map[string]any, model *entity.AsynchModel) error { + // 1) 获取校验配置,并取值 + requestMapping := model.RequestMapping + contentKey := "" + for k := range model.ResponseBody { + contentKey = k + break + } + contentStr, ok := raw[contentKey].(string) + if !ok || contentStr == "" { + return fmt.Errorf("%s 字段为空或不是字符串", contentKey) + } + + // 2) 解析 content 为 JSON 数组 + var rounds []map[string]any + if err := gjson.DecodeTo(contentStr, &rounds); err != nil { + return fmt.Errorf("解析 content JSON 数组失败: %w", err) + } + if len(rounds) == 0 { + return fmt.Errorf("content 数组为空") + } + + // 3) 逐条校验:只检查默认值为空的必填字段是否存在 + for i, round := range rounds { + for path, defaultValue := range requestMapping { + if !g.IsEmpty(defaultValue) { + continue + } + if gjson.New(round).Get(path).IsNil() { + return fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, path) + } + } + } + return nil +} + +// ReverseMap 映射 payload 到 mapping +func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any { + jsonObj := gjson.New("{}") + for path, defaultValue := range mapping { + // 从 payload 取对应路径的值 + val := gjson.New(payload).Get(path) + if !val.IsNil() { + // payload 有值,用它 + _ = jsonObj.Set(path, val.Val()) + } else if !g.IsEmpty(defaultValue) { + // payload 没值,用默认值 + _ = jsonObj.Set(path, defaultValue) + } + } + return jsonObj.Map() +} + +// MapResponsePayload 映射模型响应为标准格式 +func MapResponsePayload(mapping map[string]any, responseBytes []byte) ([]byte, error) { + if len(mapping) == 0 { + return responseBytes, nil + } + + responseJson := gjson.New(responseBytes) + resultJson := gjson.New("{}") + + for standardField, modelPath := range mapping { + path := gconv.String(modelPath) + if path == "" { + continue + } + val := responseJson.Get(path) + if val.IsNil() { + continue + } + resultJson.Set(standardField, val.Val()) + } + + return []byte(resultJson.String()), nil +} + +// ParseHeadMsgHeaders 支持多个 header 绑定,逗号分隔: +// 示例: +// - X-API-Key:qwen3-tts-key,operation:true,count:123 +// - X-API-Key:"qwen3-tts-key",operation:"true" +// +// 说明: +// - HTTP Header 最终都是字符串,这里做的是“值的字符串化表达”。 +// - 若 value 用双引号包裹,会去掉外层引号再注入,便于在配置中区分字符串/布尔/数字等表达(以及避免值中包含特殊字符时歧义)。 +func ParseHeadMsgHeaders(headMsg string) map[string]string { + headMsg = strings.TrimSpace(headMsg) + if headMsg == "" { + return nil + } + out := map[string]string{} + parts := strings.Split(headMsg, ",") + for _, p := range parts { + p = strings.TrimSpace(p) + if p == "" { + continue + } + // HeaderName:HeaderValue(推荐) / HeaderName=HeaderValue(兼容) + if strings.Contains(p, ":") { + kv := strings.SplitN(p, ":", 2) + k := strings.TrimSpace(kv[0]) + v := strings.TrimSpace(kv[1]) + v = strings.Trim(v, "\"") + if k != "" && v != "" { + out[k] = v + } + continue + } + if strings.Contains(p, "=") { + kv := strings.SplitN(p, "=", 2) + k := strings.TrimSpace(kv[0]) + v := strings.TrimSpace(kv[1]) + v = strings.Trim(v, "\"") + if k != "" && v != "" { + out[k] = v + } + continue + } + } + if len(out) == 0 { + return nil + } + return out +} + +// PayloadToQuery 将 payload 转为 url.Values +func PayloadToQuery(payload map[string]any) (url.Values, error) { + q := url.Values{} + for k, v := range payload { + if v == nil { + continue + } + q.Set(k, gconv.String(v)) + } + return q, nil +} diff --git a/common/util/network.go b/common/util/network.go new file mode 100644 index 0000000..ffc98d6 --- /dev/null +++ b/common/util/network.go @@ -0,0 +1,140 @@ +package util + +import ( + "context" + "net" + "strings" + + "github.com/gogf/gf/v2/frame/g" +) + +// GetLocalIP 获取本机有效的局域网 IPv4 地址 +func GetLocalIP() string { + addrs, err := net.InterfaceAddrs() + if err != nil { + return "127.0.0.1" + } + + var validIPs []string + + for _, addr := range addrs { + ipnet, ok := addr.(*net.IPNet) + if !ok { + continue + } + + ip := ipnet.IP + + if isIPValid(ip) { + validIPs = append(validIPs, ip.String()) + } + } + + // 优先返回非 169.254.x.x 的 IP + for _, ip := range validIPs { + if !strings.HasPrefix(ip, "169.254.") { + return ip + } + } + + // 其次返回 169.254.x.x(最后的选择) + if len(validIPs) > 0 { + return validIPs[0] + } + + return "127.0.0.1" +} + +// isIPValid 判断 IP 是否有效 +func isIPValid(ip net.IP) bool { + // 不是 loopback (127.0.0.1) + if ip.IsLoopback() { + return false + } + + // 是 IPv4 + if ip.To4() == nil { + return false + } + + // 不是链路本地地址 (169.254.0.0/16) + if ip[0] == 169 && ip[1] == 254 { + return false + } + + // 不是组播地址 + if ip.IsMulticast() { + return false + } + + // 不是未指定地址 (0.0.0.0) + if ip.IsUnspecified() { + return false + } + + return true +} + +// GetLocalAddress 获取局域网地址(IP:端口) +func GetLocalAddress(ctx context.Context) string { + ip := GetLocalIP() + port := GetServerPort(ctx) + + if port == "80" || port == "443" { + return ip + } + return ip + ":" + port +} + +// GetSchemaFromRequest 从当前请求中获取协议(http/https) +func GetSchemaFromRequest(ctx context.Context) string { + r := g.RequestFromCtx(ctx) + if r == nil { + return "http" + } + + // 1. 代理场景:X-Forwarded-Proto + if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" { + return proto + } + + // 2. 代理场景:X-Forwarded-Scheme + if proto := r.Header.Get("X-Forwarded-Scheme"); proto != "" { + return proto + } + + // 3. TLS 连接(直接 HTTPS) + if r.TLS != nil { + return "https" + } + + // 4. 默认 HTTP(这行很重要!) + return "http" // ← 确保有这行 +} + +// GetLocalBaseURL 获取局域网基础 URL(动态协议 + IP + 端口) +func GetLocalBaseURL(ctx context.Context) string { + schema := GetSchemaFromRequest(ctx) + addr := GetLocalAddress(ctx) + return schema + "://" + addr +} + +// GetCallbackURL 获取回调地址(完整 URL) +func GetCallbackURL(ctx context.Context, path string) string { + baseURL := GetLocalBaseURL(ctx) + // 确保 path 以 / 开头 + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + return baseURL + path +} + +// GetServerPort 从配置获取服务端口 +func GetServerPort(ctx context.Context) string { + address := g.Cfg().MustGet(ctx, "server.address", ":8080").String() + // address 格式如 ":3009",去掉冒号 + if strings.HasPrefix(address, ":") { + return address[1:] + } + return "8080" +} diff --git a/consts/public/public.go b/consts/public/public.go index 26a6dd4..6349997 100644 --- a/consts/public/public.go +++ b/consts/public/public.go @@ -4,11 +4,12 @@ package public const ( ModelTypeInference = 100 // 推理模型 - ModelTypeImage = 200 // 图片模型 - ImageSubTypeTextToImage = 201 // 图片模型-文生图 - ImageSubTypeImageToImage = 202 // 图片模型-图生图 - ImageSubTypeImageEdit = 203 // 图片模型-图片编辑 - ImageSubTypeImageVariation = 204 // 图片模型-图片变体 + ModelTypeImage = 200 // 图片模型 + ImageSubTypeTextToImage = 201 // 图片模型-文生图 + ImageSubTypeImageToImage = 202 // 图片模型-图生图 + ImageSubTypeImageEdit = 203 // 图片模型-图片编辑 + ImageSubTypeImageVariation = 204 // 图片模型-图片变体 + ImageSubTypeImageTextToImage = 205 // 图片模型-图文生图 ModelTypeAudio = 300 // 音频模型 AudioSubTypeTextToSpeech = 301 // 音频模型-文生音 @@ -35,11 +36,12 @@ const ( var ModelTypeName = map[int]string{ ModelTypeInference: "推理模型", - ModelTypeImage: "图片模型", - ImageSubTypeTextToImage: "图片模型-文生图", - ImageSubTypeImageToImage: "图片模型-图生图", - ImageSubTypeImageEdit: "图片模型-图片编辑", - ImageSubTypeImageVariation: "图片模型-图片变体", + ModelTypeImage: "图片模型", + ImageSubTypeTextToImage: "图片模型-文生图", + ImageSubTypeImageToImage: "图片模型-图生图", + ImageSubTypeImageEdit: "图片模型-图片编辑", + ImageSubTypeImageVariation: "图片模型-图片变体", + ImageSubTypeImageTextToImage: "图片模型-图文生图", ModelTypeAudio: "音频模型", AudioSubTypeTextToSpeech: "音频模型-文生音", diff --git a/controller/model_controller.go b/controller/model_controller.go index 2411b04..e81a39f 100644 --- a/controller/model_controller.go +++ b/controller/model_controller.go @@ -2,9 +2,9 @@ package controller import ( "context" - "model-gateway/model/dto" - "model-gateway/service" + modelService "model-gateway/service/model" + "model-gateway/service/queue" ) type model struct{} @@ -14,53 +14,53 @@ var Model = new(model) // CreateModel 添加配置 func (c *model) CreateModel(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) { - return service.Model.Create(ctx, req) + return modelService.Model.Create(ctx, req) } // UpdateModel 更改配置 func (c *model) UpdateModel(ctx context.Context, req *dto.UpdateModelReq) (res *dto.UpdateModelRes, err error) { - err = service.Model.Update(ctx, req) + err = modelService.Model.Update(ctx, req) return } // DeleteModel 删除配置 func (c *model) DeleteModel(ctx context.Context, req *dto.DeleteModelReq) (res *dto.DeleteModelRes, err error) { - err = service.Model.Delete(ctx, req) + err = modelService.Model.Delete(ctx, req) return } // GetModel 获取配置详情 func (c *model) GetModel(ctx context.Context, req *dto.GetModelReq) (res *dto.GetModelRes, err error) { - return service.Model.Get(ctx, req) + return modelService.Model.Get(ctx, req) } // ListModel 配置列表 func (c *model) ListModel(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) { - return service.Model.List(ctx, req) + return modelService.Model.List(ctx, req) } // AutoTune 动态调参(由上层定时任务每小时触发一次) func (c *model) AutoTune(ctx context.Context, req *dto.AutoTuneReq) (res *dto.AutoTuneRes, err error) { - return service.AutoTune(ctx, req) + return queue.AutoTune(ctx, req) } // ListType 模型类型列表 func (c *model) ListType(ctx context.Context, req *dto.ListTypeReq) (res *dto.TypeItem, err error) { - return service.GetModelTypesFromConfig() + return modelService.GetModelTypesFromConfig() } // ListOperator 运营商列表 func (c *model) ListOperator(ctx context.Context, req *dto.ListOperatorReq) (res *dto.ListOperatorRes, err error) { - return service.GetOperatorList() + return modelService.GetOperatorList() } // UpdateChatModel 更新是否为聊天模型 func (c *model) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) (res *dto.UpdateChatModelRes, err error) { - err = service.Model.UpdateChatModel(ctx, req) + err = modelService.Model.UpdateChatModel(ctx, req) return } // GetIsChatModel 获取当前会话模型 func (c *model) GetIsChatModel(ctx context.Context, req *dto.GetIsChatModelReq) (res *dto.GetIsChatModelRes, err error) { - return service.Model.GetIsChatModel(ctx) + return modelService.Model.GetIsChatModel(ctx) } diff --git a/controller/stat_controller.go b/controller/stat_controller.go index dba21ad..64325e4 100644 --- a/controller/stat_controller.go +++ b/controller/stat_controller.go @@ -2,9 +2,9 @@ package controller import ( "context" + statService "model-gateway/service/stat" "model-gateway/model/dto" - "model-gateway/service" ) type stat struct{} @@ -14,5 +14,5 @@ var Stat = new(stat) // ListModelStat 统计列表 func (c *stat) ListModelStat(ctx context.Context, req *dto.ListModelStatReq) (res *dto.ListModelStatRes, err error) { - return service.Stat.List(ctx, req) + return statService.Stat.List(ctx, req) } diff --git a/controller/task_controller.go b/controller/task_controller.go index f62a7bb..44be51f 100644 --- a/controller/task_controller.go +++ b/controller/task_controller.go @@ -2,9 +2,10 @@ package controller import ( "context" + "model-gateway/service/job" + taskService "model-gateway/service/task" "model-gateway/model/dto" - "model-gateway/service" ) type task struct{} @@ -14,30 +15,30 @@ var Task = new(task) // CreateTask 根据 modelName 创建异步任务,返回 taskId func (c *task) CreateTask(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) { - return service.Task.Create(ctx, req) + return taskService.Task.Create(ctx, req) } // GetTaskResult 获取任务结果(只返回 oss 地址 + state) func (c *task) GetTaskResult(ctx context.Context, req *dto.GetTaskResultReq) (res *dto.GetTaskResultRes, err error) { - return service.Task.GetResult(ctx, req.TaskID) + return taskService.Task.GetResult(ctx, req.TaskID) } // GetTaskBatch 批量查询任务(成功任务标记为已下载) func (c *task) GetTaskBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) { - return service.Task.GetBatch(ctx, req) + return taskService.Task.GetBatch(ctx, req) } // ListTask 任务列表分页查询 func (c *task) ListTask(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) { - return service.Task.List(ctx, req) + return taskService.Task.List(ctx, req) } // RunWork 手动触发一次 worker(由上层定时任务调用) func (c *task) RunWork(ctx context.Context, req *dto.RunWorkReq) (res *dto.RunWorkRes, err error) { - return service.AsyncWorker.RunOnce(ctx, req) + return taskService.AsyncWorker.RunOnce(ctx, req) } // CleanWork 手动触发一次 cleaner(由上层定时任务调用) func (c *task) CleanWork(ctx context.Context, req *dto.CleanWorkReq) (res *dto.CleanWorkRes, err error) { - return service.Cleaner.RunOnce(ctx) + return job.Cleaner.RunOnce(ctx) } diff --git a/dao/model_dao.go b/dao/model_dao.go index 92e6d91..b6644bd 100644 --- a/dao/model_dao.go +++ b/dao/model_dao.go @@ -5,6 +5,7 @@ import ( "model-gateway/consts/public" "model-gateway/model/dto" "model-gateway/model/entity" + "strconv" "gitea.com/red-future/common/db/gfdb" "github.com/gogf/gf/v2/frame/g" @@ -90,22 +91,28 @@ func (d *modelDao) GetByAcrossTenant(ctx context.Context, req *entity.AsynchMode // GetByCreatorAndPlatform 按创建者、平台获取 func (d *modelDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.AsynchModel, total int, err error) { - // 基础 SQL sql := ` SELECT DISTINCT ON (model_name) * FROM asynch_models WHERE deleted_at IS NULL AND (? = '' OR model_name LIKE ?) - AND (? = 0 OR model_type = ?) ` args := []any{ req.ModelName, "%" + req.ModelName + "%", - req.ModelType, req.ModelType, } + + // modelType: 传 6 模糊匹配 6% + if req.ModelType > 0 { + prefix := strconv.Itoa(req.ModelType)[:1] // 截取第一位 + sql += ` AND model_type::text LIKE ? ` + args = append(args, prefix+"%") + } + if !g.IsEmpty(req.IsPrivate) { sql += ` AND is_private = ? ` args = append(args, req.IsPrivate) } + if req.IsOwner != nil && *req.IsOwner == 0 { if req.Enabled != nil && *req.Enabled == 1 { sql += ` AND creator = ? AND is_owner = ? AND enabled=1 ` @@ -114,9 +121,7 @@ WHERE deleted_at IS NULL } else { sql += ` AND creator = ? AND is_owner = ? ` } - - args = append(args, req.Creator) - args = append(args, req.IsOwner) + args = append(args, req.Creator, req.IsOwner) } else if req.IsOwner != nil && *req.IsOwner == 1 { if req.Enabled != nil && *req.Enabled == 1 { sql += ` AND ((creator = ? AND is_owner = ? AND enabled=1) OR (is_owner = 0 AND enabled=1)) ` @@ -125,11 +130,9 @@ WHERE deleted_at IS NULL } else { sql += ` AND ((creator = ? AND is_owner = ?) OR (is_owner = 0 AND enabled=1)) ` } - args = append(args, req.Creator) - args = append(args, req.IsOwner) + args = append(args, req.Creator, req.IsOwner) } - // 最后拼接排序 sql += ` ORDER BY model_name, is_owner DESC, created_at DESC` r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, sql, args...) diff --git a/main.go b/main.go index 61c0cdb..922fe7e 100644 --- a/main.go +++ b/main.go @@ -3,13 +3,14 @@ package main import ( "context" "model-gateway/model/dto" + "model-gateway/service/job" + "model-gateway/service/task" "os" "os/signal" "syscall" "time" "model-gateway/controller" - "model-gateway/service" "gitea.com/red-future/common/http" "gitea.com/red-future/common/jaeger" @@ -62,7 +63,7 @@ func startAutoRunner(ctx context.Context) { case <-ctx.Done(): return case <-ticker.C: - if _, err := service.AsyncWorker.RunOnce(ctx, &dto.RunWorkReq{ + if _, err := task.AsyncWorker.RunOnce(ctx, &dto.RunWorkReq{ BatchSize: batchSize, Goroutines: goroutines, }); err != nil { @@ -87,7 +88,7 @@ func startAutoRunner(ctx context.Context) { case <-ctx.Done(): return case <-ticker.C: - _, _ = service.Cleaner.RunOnce(ctx) + _, _ = job.Cleaner.RunOnce(ctx) } } }() diff --git a/model/dto/model_dto.go b/model/dto/model_dto.go index 693ed5f..3fc7de0 100644 --- a/model/dto/model_dto.go +++ b/model/dto/model_dto.go @@ -10,33 +10,33 @@ import ( // CreateModelReq 添加模型配置 type CreateModelReq struct { g.Meta `path:"/createModel" method:"post" tags:"模型管理" summary:"创建模型配置" dc:"添加新的模型配置"` - ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"模型名称(唯一标识)"` - ModelType int `p:"modelType" json:"modelType" v:"required#modelType不能为空" dc:"模型类型:1-文本生成 2-图像生成 3-语音 4-视频 5-多模态"` - BaseURL string `p:"baseUrl" json:"baseUrl" v:"required#baseUrl不能为空" dc:"模型服务基础地址(如 gateway(s)://host:port)"` - HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式:GET/POST(默认POST)"` - HeadMsg string `p:"headMsg" json:"headMsg" dc:"请求头绑定(支持多个,逗号分隔),示例:Authorization:Bearer xxx,Content-Type:application/json"` - IsPrivate *int `p:"isPrivate" json:"isPrivate" v:"in:0,1#私有化参数只能为0或1" dc:"是否私有化:0-私有(默认) 1-公共"` - Enabled *int `p:"enabled" json:"enabled" v:"in:0,1#启用参数只能为0或1" dc:"是否启用:0-禁用,1-启用(默认1)"` - IsChatModel *int `p:"isChatModel" json:"isChatModel" v:"in:0,1#对话模型参数只能为0或1" dc:"是否为对话模型:0-否,1-是(默认0)"` - IsOwner *int `p:"isOwner" json:"isOwner" v:"in:0,1#是否为所有者参数只能为0或1" dc:"是否为所有者:0-否,1-是(默认0)"` - OperatorName string `p:"operatorName" json:"operatorName" v:"required#operatorName不能为空" dc:"运营商名称"` - TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"` - ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"` - QueryConfig map[string]any `p:"queryConfig" json:"queryConfig" dc:"结果配置"` - ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥,用于模型认证"` - Form map[string]any `p:"form" json:"form" dc:"动态表单配置(JSON),用于前端渲染配置项"` - RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求映射"` - ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回映射"` - ResponseBody map[string]any `p:"responseBody" json:"responseBody" dc:"返回主体"` - ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"` - MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数(默认10)"` - QueueLimit int `p:"queueLimit" json:"queueLimit" dc:"排队队列上限(默认1000)"` - TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒,默认600)"` - ExpectedSeconds int `p:"expectedSeconds" json:"expectedSeconds" dc:"模型预计执行时间(秒,默认600)"` - RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数(默认3)"` - RetryQueueMaxSeconds int `p:"retryQueueMaxSeconds" json:"retryQueueMaxSeconds" dc:"失败重试最大排队时间(秒,默认600)"` - AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"任务完成后自动清理时间(秒,默认86400)"` - Remark string `p:"remark" json:"remark" dc:"备注说明"` + ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"模型名称(唯一标识)"` + ModelType int `p:"modelType" json:"modelType" v:"required#modelType不能为空" dc:"模型类型:1-文本生成 2-图像生成 3-语音 4-视频 5-多模态"` + BaseURL string `p:"baseUrl" json:"baseUrl" v:"required#baseUrl不能为空" dc:"模型服务基础地址(如 gateway(s)://host:port)"` + HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式:GET/POST(默认POST)"` + HeadMsg string `p:"headMsg" json:"headMsg" dc:"请求头绑定(支持多个,逗号分隔),示例:Authorization:Bearer xxx,Content-Type:application/json"` + IsPrivate *int `p:"isPrivate" json:"isPrivate" v:"in:0,1#私有化参数只能为0或1" dc:"是否私有化:0-私有(默认) 1-公共"` + Enabled *int `p:"enabled" json:"enabled" v:"in:0,1#启用参数只能为0或1" dc:"是否启用:0-禁用,1-启用(默认1)"` + IsChatModel *int `p:"isChatModel" json:"isChatModel" v:"in:0,1#对话模型参数只能为0或1" dc:"是否为对话模型:0-否,1-是(默认0)"` + IsOwner *int `p:"isOwner" json:"isOwner" v:"in:0,1#是否为所有者参数只能为0或1" dc:"是否为所有者:0-否,1-是(默认0)"` + OperatorName string `p:"operatorName" json:"operatorName" v:"required#operatorName不能为空" dc:"运营商名称"` + TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"` + ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"` + QueryConfig map[string]any `p:"queryConfig" json:"queryConfig" dc:"结果配置"` + ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥,用于模型认证"` + Form []map[string]any `p:"form" json:"form" dc:"动态表单配置(JSON),用于前端渲染配置项"` + RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求映射"` + ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回映射"` + ResponseBody map[string]any `p:"responseBody" json:"responseBody" dc:"返回主体"` + ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"` + MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数(默认10)"` + QueueLimit int `p:"queueLimit" json:"queueLimit" dc:"排队队列上限(默认1000)"` + TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒,默认600)"` + ExpectedSeconds int `p:"expectedSeconds" json:"expectedSeconds" dc:"模型预计执行时间(秒,默认600)"` + RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数(默认3)"` + RetryQueueMaxSeconds int `p:"retryQueueMaxSeconds" json:"retryQueueMaxSeconds" dc:"失败重试最大排队时间(秒,默认600)"` + AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"任务完成后自动清理时间(秒,默认86400)"` + Remark string `p:"remark" json:"remark" dc:"备注说明"` } type CreateModelRes struct { @@ -45,34 +45,34 @@ type CreateModelRes struct { type UpdateModelReq struct { g.Meta `path:"/updateModel" method:"put" tags:"模型管理" summary:"更新模型配置" dc:"更新指定ID的模型配置"` - ID int64 `p:"id" json:"id" v:"required#id不能为空" dc:"配置ID"` - ModelName string `p:"modelName" json:"modelName" dc:"模型名称(唯一标识)"` - ModelType int `p:"modelType" json:"modelType" dc:"模型类型ID列表(逗号分隔)(可选更新)"` - BaseURL string `p:"baseUrl" json:"baseUrl" dc:"模型服务基础地址"` - HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式:GET/POST(可选更新)"` - HeadMsg string `p:"headMsg" json:"headMsg" dc:"请求头绑定(可选更新)"` - ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥,用于模型认证(可选更新)"` - Form map[string]any `p:"form" json:"form" dc:"动态表单配置(JSON)(可选更新)"` - RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求参数映射(可选更新)"` - ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回参数映射(可选更新)"` - ResponseBody map[string]any `p:"responseBody" json:"responseBody" dc:"返回主体(可选更新)"` - ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"` - Enabled *int `p:"enabled" json:"enabled" dc:"是否启用:0-禁用,1-启用(可选更新)"` - IsPrivate *int `p:"isPrivate" json:"isPrivate" v:"in:0,1#私有化参数只能为0或1" dc:"是否私有化:0-私有(默认) 1-公共"` - IsChatModel *int `p:"isChatModel" json:"isChatModel" v:"in:0,1#对话模型参数只能为0或1" dc:"是否为对话模型:0-否,1-是(默认0)"` - IsOwner *int `p:"isOwner" json:"isOwner" v:"in:0,1#是否为所有者参数只能为0或1" dc:"是否为所有者:0-否,1-是(默认0)"` - OperatorName string `p:"operatorName" json:"operatorName" v:"required#operatorName不能为空" dc:"运营商名称"` - TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"` - ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"` - QueryConfig map[string]any `p:"queryConfig" json:"queryConfig" dc:"结果配置"` - MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数(可选更新)"` - QueueLimit int `p:"queueLimit" json:"queueLimit" dc:"排队队列上限(可选更新)"` - TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒)(可选更新)"` - ExpectedSeconds int `p:"expectedSeconds" json:"expectedSeconds" dc:"模型预计执行时间(秒)(可选更新)"` - RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数(可选更新)"` - RetryQueueMaxSeconds int `p:"retryQueueMaxSeconds" json:"retryQueueMaxSeconds" dc:"失败重试最大排队时间(秒)(可选更新)"` - AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"自动清理间隔(秒)(可选更新)"` - Remark string `p:"remark" json:"remark" dc:"备注说明(可选更新)"` + ID int64 `p:"id" json:"id" v:"required#id不能为空" dc:"配置ID"` + ModelName string `p:"modelName" json:"modelName" dc:"模型名称(唯一标识)"` + ModelType int `p:"modelType" json:"modelType" dc:"模型类型ID列表(逗号分隔)(可选更新)"` + BaseURL string `p:"baseUrl" json:"baseUrl" dc:"模型服务基础地址"` + HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式:GET/POST(可选更新)"` + HeadMsg string `p:"headMsg" json:"headMsg" dc:"请求头绑定(可选更新)"` + ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥,用于模型认证(可选更新)"` + Form []map[string]any `p:"form" json:"form" dc:"动态表单配置(JSON)(可选更新)"` + RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求参数映射(可选更新)"` + ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回参数映射(可选更新)"` + ResponseBody map[string]any `p:"responseBody" json:"responseBody" dc:"返回主体(可选更新)"` + ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"` + Enabled *int `p:"enabled" json:"enabled" dc:"是否启用:0-禁用,1-启用(可选更新)"` + IsPrivate *int `p:"isPrivate" json:"isPrivate" v:"in:0,1#私有化参数只能为0或1" dc:"是否私有化:0-私有(默认) 1-公共"` + IsChatModel *int `p:"isChatModel" json:"isChatModel" v:"in:0,1#对话模型参数只能为0或1" dc:"是否为对话模型:0-否,1-是(默认0)"` + IsOwner *int `p:"isOwner" json:"isOwner" v:"in:0,1#是否为所有者参数只能为0或1" dc:"是否为所有者:0-否,1-是(默认0)"` + OperatorName string `p:"operatorName" json:"operatorName" v:"required#operatorName不能为空" dc:"运营商名称"` + TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"` + ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"` + QueryConfig map[string]any `p:"queryConfig" json:"queryConfig" dc:"结果配置"` + MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数(可选更新)"` + QueueLimit int `p:"queueLimit" json:"queueLimit" dc:"排队队列上限(可选更新)"` + TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒)(可选更新)"` + ExpectedSeconds int `p:"expectedSeconds" json:"expectedSeconds" dc:"模型预计执行时间(秒)(可选更新)"` + RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数(可选更新)"` + RetryQueueMaxSeconds int `p:"retryQueueMaxSeconds" json:"retryQueueMaxSeconds" dc:"失败重试最大排队时间(秒)(可选更新)"` + AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"自动清理间隔(秒)(可选更新)"` + Remark string `p:"remark" json:"remark" dc:"备注说明(可选更新)"` } type UpdateModelRes struct { @@ -166,3 +166,20 @@ type GetIsChatModelReq struct { type GetIsChatModelRes struct { Model any `json:"model" dc:"模型详情"` } + +// NodeFormField 节点表单 +type NodeFormField struct { + Value any `json:"value" dc:"字段值"` + Field string `json:"field" dc:"字段标识"` + Label string `json:"label" dc:"字段标签"` + Type string `json:"type" dc:"字段类型"` + Required bool `json:"required" dc:"是否必填"` + Default any `json:"default,omitempty" dc:"默认值"` + Options []SelectOption `json:"options" dc:"下拉选项列表"` + FieldConstraint any `json:"fieldConstraint" dc:"字段约束"` +} + +type SelectOption struct { + Label string `json:"label" dc:"选项标签"` + Value string `json:"value" dc:"选项值"` +} diff --git a/model/entity/asynch_model.go b/model/entity/asynch_model.go index ab59c7d..7abc09c 100644 --- a/model/entity/asynch_model.go +++ b/model/entity/asynch_model.go @@ -69,32 +69,32 @@ var AsynchModelCol = asynchModelCol{ // AsynchModel 异步模型配置 type AsynchModel struct { beans.SQLBaseDO `orm:",inline"` - ModelName string `orm:"model_name" json:"modelName"` - ModelType int `orm:"model_type" json:"modelType"` - BaseURL string `orm:"base_url" json:"baseUrl"` - HttpMethod string `orm:"http_method" json:"httpMethod"` - HeadMsg string `orm:"head_msg" json:"headMsg"` - Form map[string]any `orm:"form_json" json:"form"` - RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"` - ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"` - ResponseBody map[string]any `orm:"response_body" json:"responseBody"` - ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"` - Prompt string `orm:"prompt" json:"prompt"` - IsPrivate *int `orm:"is_private" json:"isPrivate"` - IsChatModel *int `orm:"is_chat_model" json:"isChatModel"` - ApiKey string `orm:"api_key" json:"apiKey"` - Enabled *int `orm:"enabled" json:"enabled"` - MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"` - QueueLimit int `orm:"queue_limit" json:"queueLimit"` - TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"` - ExpectedSeconds int `orm:"expected_seconds" json:"expectedSeconds"` - RetryTimes int `orm:"retry_times" json:"retryTimes"` - RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"retryQueueMaxSeconds"` - AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"` - Remark string `orm:"remark" json:"remark"` - IsOwner *int `json:"isOwner" orm:"is_owner"` - OperatorName string `orm:"operator_name" json:"operatorName"` - TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"` - ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"` - QueryConfig map[string]any `orm:"query_config" json:"queryConfig"` + ModelName string `orm:"model_name" json:"modelName"` + ModelType int `orm:"model_type" json:"modelType"` + BaseURL string `orm:"base_url" json:"baseUrl"` + HttpMethod string `orm:"http_method" json:"httpMethod"` + HeadMsg string `orm:"head_msg" json:"headMsg"` + Form []map[string]any `orm:"form_json" json:"form"` + RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"` + ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"` + ResponseBody map[string]any `orm:"response_body" json:"responseBody"` + ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"` + Prompt string `orm:"prompt" json:"prompt"` + IsPrivate *int `orm:"is_private" json:"isPrivate"` + IsChatModel *int `orm:"is_chat_model" json:"isChatModel"` + ApiKey string `orm:"api_key" json:"apiKey"` + Enabled *int `orm:"enabled" json:"enabled"` + MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"` + QueueLimit int `orm:"queue_limit" json:"queueLimit"` + TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"` + ExpectedSeconds int `orm:"expected_seconds" json:"expectedSeconds"` + RetryTimes int `orm:"retry_times" json:"retryTimes"` + RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"retryQueueMaxSeconds"` + AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"` + Remark string `orm:"remark" json:"remark"` + IsOwner *int `json:"isOwner" orm:"is_owner"` + OperatorName string `orm:"operator_name" json:"operatorName"` + TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"` + ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"` + QueryConfig map[string]any `orm:"query_config" json:"queryConfig"` } diff --git a/service/file_detect.go b/service/file_detect.go deleted file mode 100644 index 6d43c33..0000000 --- a/service/file_detect.go +++ /dev/null @@ -1 +0,0 @@ -package service diff --git a/service/cleaner.go b/service/job/cleaner.go similarity index 72% rename from service/cleaner.go rename to service/job/cleaner.go index 39bac53..b6b7d47 100644 --- a/service/cleaner.go +++ b/service/job/cleaner.go @@ -1,8 +1,9 @@ -package service +package job import ( "context" "model-gateway/model/dto" + "model-gateway/service/queue" "os" "time" @@ -20,32 +21,32 @@ func (c *cleaner) RunOnce(ctx context.Context) (res *dto.CleanWorkRes, err error // 1) 清理已下载(state=4)且过期的任务(硬删除 + OSS) expired, err := dao.Task.ListExpiredDownloadedGlobal(ctx, 200) if err != nil { - g.Log().Errorf(ctx, "[cleaner] list expired(downloaded) error: %v", err) + g.Log().Errorf(ctx, "[清理] 查询已下载过期任务失败: %v", err) } else { for _, t := range expired { _ = os.Remove(t.TmpFile) _ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id) } - g.Log().Infof(ctx, "[cleaner] expired(downloaded) cleaned, count=%d", len(expired)) + g.Log().Infof(ctx, "[清理] 已下载过期任务清理完成, count=%d", len(expired)) } // 2) 超时任务标失败 list, err := dao.Task.ListTimeoutTasksGlobal(ctx, 200) if err != nil { - g.Log().Errorf(ctx, "[cleaner] list timeout error: %v", err) + g.Log().Errorf(ctx, "[清理] 查询超时任务失败: %v", err) } else { for _, t := range list { t.ErrorMsg = "任务超时自动失败" _ = dao.Task.UpdateFailedGlobal(ctx, t) - ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) + queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) } - g.Log().Infof(ctx, "[cleaner] timeout cleaned, count=%d", len(list)) + g.Log().Infof(ctx, "[清理] 超时任务处理完成, count=%d", len(list)) } // 3) 失败(state=3)的任务按模型配置 retry_times 重新入队(放到队尾) retryable, err := dao.Task.ListFailedRetryableGlobal(ctx, 200) if err != nil { - g.Log().Errorf(ctx, "[cleaner] list failed retryable error: %v", err) + g.Log().Errorf(ctx, "[清理] 查询可重试任务失败: %v", err) } else { for _, t := range retryable { // 失败任务重新入队(state=3 -> 0)前,先严格占用 queue_limit slot;占用失败则留在失败态,下一轮再尝试 @@ -54,9 +55,9 @@ func (c *cleaner) RunOnce(ctx context.Context) (res *dto.CleanWorkRes, err error if err != nil || m == nil { continue } - limit := GetRuntimeQueueLimit(ctx, t.ModelName, m.QueueLimit) + limit := queue.GetRuntimeQueueLimit(ctx, t.ModelName, m.QueueLimit) if limit > 0 { - ok, _ := AcquireQueueSlot(ctx, t.ModelName, t.TaskID, limit, m.ExpectedSeconds) + ok, _ := queue.AcquireQueueSlot(ctx, t.ModelName, t.TaskID, limit, m.ExpectedSeconds) if !ok { continue } @@ -76,21 +77,21 @@ func (c *cleaner) RunOnce(ctx context.Context) (res *dto.CleanWorkRes, err error } _ = dao.Task.RequeueForRetryGlobal(ctx, t.Id, enqueueAt) } - g.Log().Infof(ctx, "[cleaner] failed retryable cleaned, count=%d", len(retryable)) + g.Log().Infof(ctx, "[清理] 可重试任务重新入队完成, count=%d", len(retryable)) } // 4) 超过重试次数仍失败(state=3)的任务:硬删除 exhausted, err := dao.Task.ListFailedExhaustedGlobal(ctx, 200) if err != nil { - g.Log().Errorf(ctx, "[cleaner] list failed exhausted error: %v", err) + g.Log().Errorf(ctx, "[清理] 查询重试耗尽任务失败: %v", err) } else { for _, t := range exhausted { _ = os.Remove(t.TmpFile) // 重试耗尽硬删除:释放闸门占位(兜底,若此前已释放则幂等) - ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) + queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) _ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id) } - g.Log().Infof(ctx, "[cleaner] failed exhausted cleaned, count=%d", len(exhausted)) + g.Log().Infof(ctx, "[清理] 重试耗尽任务清理完成, count=%d", len(exhausted)) } return &dto.CleanWorkRes{ Ok: true, diff --git a/service/model/model_service.go b/service/model/model_service.go new file mode 100644 index 0000000..0942948 --- /dev/null +++ b/service/model/model_service.go @@ -0,0 +1,254 @@ +package model + +import ( + "context" + "errors" + "model-gateway/common/util" + "model-gateway/consts/public" + "model-gateway/dao" + "model-gateway/model/dto" + "model-gateway/model/entity" + "model-gateway/service/gateway" + + "gitea.com/red-future/common/beans" + "gitea.com/red-future/common/db/gfdb" + "gitea.com/red-future/common/utils" + "github.com/gogf/gf/v2/database/gdb" + "github.com/gogf/gf/v2/frame/g" + "github.com/gogf/gf/v2/util/gconv" +) + +var Model = &modelService{} + +type modelService struct{} + +// Create 创建模型 +func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (*dto.CreateModelRes, error) { + // 1)如果设为会话模型,先把该用户旧会话模型取消 + if !g.IsEmpty(req.IsChatModel) && *req.IsChatModel == 1 { + if err := s.clearUserChatModel(ctx); err != nil { + return nil, err + } + } + // 2)判断是否超管,决定 isOwner + req.IsOwner = gconv.PtrInt(1) + if isAdmin, _ := gateway.IsSuperAdmin(ctx); isAdmin { + req.IsOwner = gconv.PtrInt(0) + } + + // 3)入库 + id, err := dao.Model.Insert(ctx, util.ConvertTo[entity.AsynchModel](req)) + if err != nil { + return nil, err + } + return &dto.CreateModelRes{ID: id}, nil +} + +// Update 更新模型配置 +func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) error { + // 1)会话模型唯一性校验 + if req.IsChatModel != nil && *req.IsChatModel == 1 { + if err := s.checkChatModelUnique(ctx); err != nil { + return err + } + } + // 2)超管创建/普通用户更新 + req.IsOwner = gconv.PtrInt(1) + if isAdmin, _ := gateway.IsSuperAdmin(ctx); isAdmin { + req.IsOwner = gconv.PtrInt(0) + _, err := dao.Model.Update(ctx, util.ConvertTo[entity.AsynchModel](req)) + return err + } + // 3)跨租户判断:超管的模型不允许直接修改,走插入新记录 + model, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{ + SQLBaseDO: beans.SQLBaseDO{Id: req.ID}, + }) + if err != nil { + return err + } + if model.TenantId == 1 { + _, err = dao.Model.Insert(ctx, util.ConvertTo[entity.AsynchModel](req)) + return err + } + _, err = dao.Model.Update(ctx, util.ConvertTo[entity.AsynchModel](req)) + return err +} + +// Delete 删除模型 +func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) error { + _, err := dao.Model.Delete(ctx, &entity.AsynchModel{ + SQLBaseDO: beans.SQLBaseDO{Id: req.ID}, + }) + return err +} + +// Get 获取模型详情 +func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetModelRes, error) { + user, err := utils.GetUserInfo(ctx) + if err != nil { + return nil, err + } + if g.IsEmpty(req.ID) { + req.Creator = user.UserName + } + modelReq := new(entity.AsynchModel) + err = gconv.Struct(req, modelReq) + if err != nil { + return nil, err + } + model, err := dao.Model.Get(ctx, modelReq) + if err != nil { + return nil, err + } + return &dto.GetModelRes{ + Model: model, + }, nil +} + +// List 获取模型列表 +func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (*dto.ListModelRes, error) { + // 1)判断超管 + req.IsOwner = gconv.PtrInt(1) + if isAdmin, _ := gateway.IsSuperAdmin(ctx); isAdmin { + req.IsOwner = gconv.PtrInt(0) + } + + // 2)获取当前用户 + user, err := utils.GetUserInfo(ctx) + if err != nil { + return nil, err + } + req.Creator = user.UserName + + // 3)查询 + models, total, err := dao.Model.GetByCreatorAndPlatform(ctx, req) + if err != nil { + return nil, err + } + + return &dto.ListModelRes{List: models, Total: total}, nil +} + +// UpdateChatModel 设置会话模型 +func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error { + // 1)校验新模型存在 + newModel, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{ + SQLBaseDO: beans.SQLBaseDO{Id: req.Id}, + }) + if err != nil || newModel == nil { + return errors.New("新会话模型不存在") + } + + // 2)获取当前用户的会话模型 + user, err := utils.GetUserInfo(ctx) + if err != nil { + return err + } + currentModel, err := dao.Model.Get(ctx, &entity.AsynchModel{ + SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName}, + IsChatModel: new(1), + }) + if err != nil { + return err + } + + // 3)事务:取消旧的 + 设置新的 + return gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error { + if !g.IsEmpty(currentModel) { + if currentModel.ModelType != public.ModelTypeInference { + return errors.New("当前模型为非推理模型,不能设置为会话模型") + } + if currentModel.Id != req.Id { + _, err = dao.Model.Update(ctx, &entity.AsynchModel{ + SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id}, + IsChatModel: gconv.PtrInt(0), + }) + if err != nil { + return err + } + } + } + + _, err = dao.Model.Update(ctx, &entity.AsynchModel{ + SQLBaseDO: beans.SQLBaseDO{Id: req.Id}, + IsChatModel: gconv.PtrInt(1), + }) + return err + }) +} + +// GetIsChatModel 获取当前用户会话模型 +func (s *modelService) GetIsChatModel(ctx context.Context) (*dto.GetIsChatModelRes, error) { + user, err := utils.GetUserInfo(ctx) + if err != nil { + return nil, err + } + model, err := dao.Model.Get(ctx, &entity.AsynchModel{ + SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName}, + IsChatModel: new(1), + }) + if err != nil || model == nil { + return nil, err + } + return &dto.GetIsChatModelRes{Model: model}, nil +} + +// ==================== 辅助方法 ==================== + +// clearUserChatModel 清除当前用户旧会话模型 +func (s *modelService) clearUserChatModel(ctx context.Context) error { + user, err := utils.GetUserInfo(ctx) + if err != nil { + return err + } + model, err := dao.Model.Get(ctx, &entity.AsynchModel{ + SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName}, + IsChatModel: new(1), + }) + if err != nil || model == nil { + return nil + } + _, err = dao.Model.Update(ctx, &entity.AsynchModel{ + SQLBaseDO: beans.SQLBaseDO{Id: model.Id}, + IsChatModel: gconv.PtrInt(0), + }) + return err +} + +// checkChatModelUnique 校验用户是否已有会话模型 +func (s *modelService) checkChatModelUnique(ctx context.Context) error { + user, err := utils.GetUserInfo(ctx) + if err != nil { + return err + } + model, err := dao.Model.Get(ctx, &entity.AsynchModel{ + SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName}, + IsChatModel: new(1), + }) + if err != nil { + return err + } + if model != nil { + return errors.New("用户已存在会话模型") + } + return nil +} + +// GetModelTypesFromConfig 从配置文件读取模型类型 +func GetModelTypesFromConfig() (res *dto.TypeItem, err error) { + // 返回副本,避免外部修改 + types := make(map[int]string, len(public.ModelTypeName)) + for k, v := range public.ModelTypeName { + types[k] = v + } + return &dto.TypeItem{ + Type: types, + }, nil +} + +// GetOperatorList 获取运营商列表 +func GetOperatorList() (res *dto.ListOperatorRes, err error) { + return &dto.ListOperatorRes{ + List: public.OperatorList, + }, nil +} diff --git a/service/model_invoker.go b/service/model_invoker.go deleted file mode 100644 index 8d3ef64..0000000 --- a/service/model_invoker.go +++ /dev/null @@ -1,469 +0,0 @@ -package service - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "model-gateway/model/entity" - "net/http" - "net/url" - "strings" - "time" - - "github.com/gogf/gf/v2/container/gvar" - "github.com/gogf/gf/v2/frame/g" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" -) - -// parseHeadMsgHeaders 支持多个 header 绑定,逗号分隔: -// 示例: -// - X-API-Key:qwen3-tts-key,operation:true,count:123 -// - X-API-Key:"qwen3-tts-key",operation:"true" -// -// 说明: -// - HTTP Header 最终都是字符串,这里做的是“值的字符串化表达”。 -// - 若 value 用双引号包裹,会去掉外层引号再注入,便于在配置中区分字符串/布尔/数字等表达(以及避免值中包含特殊字符时歧义)。 -func parseHeadMsgHeaders(headMsg string) map[string]string { - headMsg = strings.TrimSpace(headMsg) - if headMsg == "" { - return nil - } - out := map[string]string{} - parts := strings.Split(headMsg, ",") - for _, p := range parts { - p = strings.TrimSpace(p) - if p == "" { - continue - } - // HeaderName:HeaderValue(推荐) / HeaderName=HeaderValue(兼容) - if strings.Contains(p, ":") { - kv := strings.SplitN(p, ":", 2) - k := strings.TrimSpace(kv[0]) - v := strings.TrimSpace(kv[1]) - v = strings.Trim(v, "\"") - if k != "" && v != "" { - out[k] = v - } - continue - } - if strings.Contains(p, "=") { - kv := strings.SplitN(p, "=", 2) - k := strings.TrimSpace(kv[0]) - v := strings.TrimSpace(kv[1]) - v = strings.Trim(v, "\"") - if k != "" && v != "" { - out[k] = v - } - continue - } - } - if len(out) == 0 { - return nil - } - return out -} - -func payloadToQuery(payload any) (url.Values, error) { - if payload == nil { - return url.Values{}, nil - } - // 统一转成 map[string]any - b, err := json.Marshal(payload) - if err != nil { - return nil, err - } - m := map[string]any{} - if err := json.Unmarshal(b, &m); err != nil { - return nil, err - } - q := url.Values{} - for k, v := range m { - if v == nil { - continue - } - // 复杂类型直接 json 字符串化 - switch vv := v.(type) { - case string: - q.Set(k, vv) - case float64, bool, int, int64, uint64: - q.Set(k, fmt.Sprintf("%v", vv)) - default: - bs, _ := json.Marshal(v) - q.Set(k, string(bs)) - } - } - return q, nil -} - -// InvokeModel 调用模型服务,返回二进制结果 -// modelKey 用于覆盖/补充模型配置 head_msg(例如每次请求携带不同的 X-API-Key)。 -func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelKey string) ([]byte, error) { - if m == nil || m.BaseURL == "" { - return nil, fmt.Errorf("模型配置不完整") - } - - // ============ 新增:请求参数映射 ============ - mappedPayload, err := mapRequestPayload(m.RequestMapping, payload) - if err != nil { - return nil, fmt.Errorf("请求参数映射失败: %w", err) - } - - url := strings.TrimRight(m.BaseURL, "/") - timeout := time.Duration(m.TimeoutSeconds) * time.Second - if timeout <= 0 { - timeout = 60 * time.Second - } - client := &http.Client{Timeout: timeout} - - method := strings.ToUpper(strings.TrimSpace(m.HttpMethod)) - if method == "" { - method = http.MethodPost - } - - var ( - req *http.Request - ) - switch method { - case http.MethodGet: - q, err := payloadToQuery(mappedPayload) // 使用映射后的payload - if err != nil { - return nil, err - } - if len(q) > 0 { - if strings.Contains(url, "?") { - url = url + "&" + q.Encode() - } else { - url = url + "?" + q.Encode() - } - } - req, err = http.NewRequestWithContext(ctx, http.MethodGet, url, nil) - default: - bodyBytes, err := json.Marshal(mappedPayload) // 使用映射后的payload - if err != nil { - return nil, err - } - req, err = http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyBytes)) - } - if err != nil { - return nil, err - } - - // 先注入模型配置 head_msg(静态头部,适合公共模型固定 API Key) - for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) { - req.Header.Set(hk, hv) - } - - // 最后注入动态 modelKey(允许覆盖/补充静态 head_msg),适合按请求动态传密钥。 - for hk, hv := range parseHeadMsgHeaders(modelKey) { - req.Header.Set(hk, hv) - } - - if method != http.MethodGet { - req.Header.Set("Content-Type", "application/json") - } - - resp, err := client.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - b, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - msg := string(b) - if len(msg) > 2000 { - msg = msg[:2000] - } - return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg) - } - - // ============ 新增:响应参数映射 ============ - mappedResponse, err := mapResponsePayload(m.ResponseMapping, b) - if err != nil { - // 响应映射失败不阻塞,返回原始数据 - g.Log().Warningf(ctx, "响应参数映射失败: %v,返回原始数据", err) - return b, nil - } - // ========================================= - - return mappedResponse, nil -} - -//// InvokeModel 调用模型服务,返回二进制结果 -//func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelKey string) ([]byte, error) { -// if m == nil || m.BaseURL == "" { -// return nil, fmt.Errorf("模型配置不完整") -// } -// // 请求参数映射 -// mappedPayload, err := mapRequestPayload(m.RequestMapping, payload) -// if err != nil { -// return nil, fmt.Errorf("请求参数映射失败: %w", err) -// } -// // 合并请求头 -// headers := util.ForwardHeaders(ctx) -// for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) { -// headers[hk] = hv -// } -// for hk, hv := range parseHeadMsgHeaders(modelKey) { -// headers[hk] = hv -// } -// -// // 设置超时 -// timeout := time.Duration(m.TimeoutSeconds) * time.Second -// if timeout <= 0 { -// timeout = 600 * time.Second -// } -// ctx, cancel := context.WithTimeout(ctx, timeout) -// defer cancel() -// -// invokeUrl := strings.TrimRight(m.BaseURL, "/") -// method := strings.ToUpper(strings.TrimSpace(m.HttpMethod)) -// if method == "" { -// method = http.MethodPost -// } -// -// var respBytes []byte -// -// switch method { -// case http.MethodGet: -// err = commonHttp.Get(ctx, invokeUrl, headers, &respBytes, mappedPayload) -// default: -// err = commonHttp.Post(ctx, invokeUrl, headers, &respBytes, mappedPayload) -// } -// if err != nil { -// return nil, err -// } -// // 响应参数映射 -// mappedResponse, err := mapResponsePayload(m.ResponseMapping, respBytes) -// if err != nil { -// g.Log().Warningf(ctx, "响应参数映射失败: %v,返回原始数据", err) -// return respBytes, nil -// } -// return mappedResponse, nil -//} - -// ============================================ -// 映射相关函数 -// ============================================ - -// mapRequestPayload 将标准请求映射为模型特定格式 -func mapRequestPayload(mappingAny any, payload any) (any, error) { - // 1. 解析请求映射配置(值是any类型,支持bool、number等) - mapping, err := parseRequestMapping(mappingAny) - if err != nil { - return nil, err - } - - // 如果没有映射配置,直接返回原始payload - if len(mapping) == 0 { - return payload, nil - } - - // 2. 将payload转为map - var payloadMap map[string]any - switch v := payload.(type) { - case map[string]any: - payloadMap = v - case []map[string]any: - // 如果传进来的是纯messages数组,包装成标准格式 - payloadMap = map[string]any{ - "messages": v, - } - default: - // 通过JSON转换 - jsonBytes, err := json.Marshal(payload) - if err != nil { - return nil, fmt.Errorf("序列化payload失败: %w", err) - } - if err := json.Unmarshal(jsonBytes, &payloadMap); err != nil { - return nil, fmt.Errorf("反序列化payload失败: %w", err) - } - } - - // 3. 用数据库固定参数覆盖/补充 - for key, value := range mapping { - if existingValue, exists := payloadMap[key]; !exists || isEmptyValue(existingValue) { - payloadMap[key] = value - } - } - - return payloadMap, nil -} - -// mapResponsePayload 将模型响应映射为标准格式 -func mapResponsePayload(mappingAny any, responseBytes []byte) ([]byte, error) { - mapping, err := parseResponseMapping(mappingAny) - if err != nil { - return nil, err - } - if len(mapping) == 0 { - return responseBytes, nil - } - - responseStr := string(responseBytes) - resultStr := `{}` - - for standardField, modelPath := range mapping { - value := gjson.Get(responseStr, modelPath) - if !value.Exists() { - continue - } - - resultStr, err = sjson.SetRaw(resultStr, standardField, value.Raw) - if err != nil { - return nil, fmt.Errorf("提取字段 %s <- %s 失败: %w", standardField, modelPath, err) - } - } - - return []byte(resultStr), nil -} - -func parseRequestMapping(mappingAny any) (map[string]any, error) { - if mappingAny == nil { - return nil, nil - } - - result := make(map[string]any) - - switch v := mappingAny.(type) { - case *gvar.Var: - if v == nil || v.IsNil() || v.IsEmpty() { - return nil, nil - } - // 尝试转成 map - if m := v.Map(); m != nil { - for k, val := range m { - result[k] = val - } - return result, nil - } - // 尝试转成 string - if s := v.String(); s != "" && s != "{}" && s != "null" { - if err := json.Unmarshal([]byte(s), &result); err != nil { - return nil, fmt.Errorf("解析请求映射字符串失败: %w", err) - } - return result, nil - } - return nil, nil - // ======================================================= - - case map[string]interface{}: - result = v - - case string: - if v == "" || v == "{}" || v == "null" { - return nil, nil - } - if err := json.Unmarshal([]byte(v), &result); err != nil { - return nil, fmt.Errorf("解析请求映射字符串失败: %w", err) - } - - case []byte: - if len(v) == 0 { - return nil, nil - } - if err := json.Unmarshal(v, &result); err != nil { - return nil, fmt.Errorf("解析请求映射字节失败: %w", err) - } - - default: - jsonBytes, err := json.Marshal(mappingAny) - if err != nil { - return nil, fmt.Errorf("序列化映射配置失败: %w", err) - } - if err := json.Unmarshal(jsonBytes, &result); err != nil { - return nil, fmt.Errorf("解析映射配置失败: %w", err) - } - } - - return result, nil -} - -// parseResponseMapping 解析响应映射配置 -// 返回值类型为 map[string]string,值都是JSON路径字符串 -func parseResponseMapping(mappingAny any) (map[string]string, error) { - if mappingAny == nil { - return nil, nil - } - - mapping := make(map[string]string) - - switch v := mappingAny.(type) { - case *gvar.Var: - if v == nil || v.IsNil() || v.IsEmpty() { - return nil, nil - } - if m := v.Map(); m != nil { - for k, val := range m { - if strVal, ok := val.(string); ok { - mapping[k] = strVal - } - } - return mapping, nil - } - if s := v.String(); s != "" && s != "{}" && s != "null" { - if err := json.Unmarshal([]byte(s), &mapping); err != nil { - return nil, fmt.Errorf("解析响应映射字符串失败: %w", err) - } - return mapping, nil - } - return nil, nil - case string: - if v == "" || v == "{}" || v == "null" { - return nil, nil - } - if err := json.Unmarshal([]byte(v), &mapping); err != nil { - return nil, fmt.Errorf("解析响应映射字符串失败: %w", err) - } - - case map[string]interface{}: - // 数据库JSONB直接返回的map - for k, val := range v { - if strVal, ok := val.(string); ok { - mapping[k] = strVal - } - } - - case []byte: - if len(v) == 0 { - return nil, nil - } - if err := json.Unmarshal(v, &mapping); err != nil { - return nil, fmt.Errorf("解析响应映射字节失败: %w", err) - } - - default: - jsonBytes, err := json.Marshal(mappingAny) - if err != nil { - return nil, fmt.Errorf("序列化响应映射配置失败: %w", err) - } - if err := json.Unmarshal(jsonBytes, &mapping); err != nil { - return nil, fmt.Errorf("解析响应映射配置失败: %w", err) - } - } - - return mapping, nil -} - -// isEmptyValue 判断值是否为空 -func isEmptyValue(v any) bool { - if v == nil { - return true - } - switch val := v.(type) { - case string: - return val == "" - case []any: - return len(val) == 0 - case map[string]any: - return len(val) == 0 - default: - return false - } -} diff --git a/service/model_service.go b/service/model_service.go deleted file mode 100644 index 7334e13..0000000 --- a/service/model_service.go +++ /dev/null @@ -1,389 +0,0 @@ -package service - -import ( - "context" - "errors" - "model-gateway/consts/public" - "model-gateway/dao" - "model-gateway/model/dto" - "model-gateway/model/entity" - "model-gateway/service/gateway" - - "gitea.com/red-future/common/beans" - "gitea.com/red-future/common/db/gfdb" - "gitea.com/red-future/common/utils" - "github.com/gogf/gf/v2/database/gdb" - "github.com/gogf/gf/v2/frame/g" - "github.com/gogf/gf/v2/util/gconv" -) - -var Model = &modelService{} - -type modelService struct{} - -func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) { - // 获取当前会话模型 - if !g.IsEmpty(req.IsChatModel) && *req.IsChatModel == 1 { - var user *beans.User - user, err = utils.GetUserInfo(ctx) - if err != nil { - return nil, err - } - // 获取当前用户会话模型 - var model *entity.AsynchModel - model, err = dao.Model.Get(ctx, &entity.AsynchModel{ - SQLBaseDO: beans.SQLBaseDO{ - Creator: user.UserName, - }, - IsChatModel: new(1), - }) - if err != nil { - return nil, err - } - // 如果有会话模型,那就改变为 0 - if model != nil { - _, err = dao.Model.Update(ctx, &entity.AsynchModel{ - SQLBaseDO: beans.SQLBaseDO{Id: model.Id}, - IsChatModel: gconv.PtrInt(0), - }) - if err != nil { - return nil, err - } - } - } - - req.IsOwner = gconv.PtrInt(1) - admin, err := gateway.IsSuperAdmin(ctx) - if err != nil { - return - } - if admin { - req.IsOwner = gconv.PtrInt(0) - } - id, err := dao.Model.Insert(ctx, &entity.AsynchModel{ - ModelName: req.ModelName, - ModelType: req.ModelType, - BaseURL: req.BaseURL, - HttpMethod: req.HttpMethod, - HeadMsg: req.HeadMsg, - Form: req.Form, - RequestMapping: req.RequestMapping, - ResponseMapping: req.ResponseMapping, - ResponseBody: req.ResponseBody, - ResponseTokenField: req.ResponseTokenField, - IsPrivate: req.IsPrivate, - IsChatModel: req.IsChatModel, - ApiKey: req.ApiKey, - Enabled: req.Enabled, - MaxConcurrency: req.MaxConcurrency, - QueueLimit: req.QueueLimit, - TimeoutSeconds: req.TimeoutSeconds, - ExpectedSeconds: req.ExpectedSeconds, - RetryTimes: req.RetryTimes, - RetryQueueMaxSeconds: req.RetryQueueMaxSeconds, - AutoCleanSeconds: req.AutoCleanSeconds, - Remark: req.Remark, - IsOwner: req.IsOwner, - OperatorName: req.OperatorName, - TokenConfig: req.TokenConfig, - ExtendMapping: req.ExtendMapping, - QueryConfig: req.QueryConfig, - }) - if err != nil { - return nil, err - } - return &dto.CreateModelRes{ID: id}, nil -} - -func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) error { - //根据当前 isChatModel 来判断是否更新模型 - if req.IsChatModel == gconv.PtrInt(1) { - user, err := utils.GetUserInfo(ctx) - if err != nil { - return err - } - // 获取当前用户会话模型 - model, err := dao.Model.Get(ctx, &entity.AsynchModel{ - SQLBaseDO: beans.SQLBaseDO{ - Creator: user.UserName, - }, - IsChatModel: new(1), - }) - if err != nil { - return err - } - if model != nil { - return errors.New("用户已存在会话模型,不能创建") - } - } - - req.IsOwner = gconv.PtrInt(1) - admin, err := gateway.IsSuperAdmin(ctx) - if err != nil { - return err - } - if admin { - req.IsOwner = gconv.PtrInt(0) - _, err = dao.Model.Update(ctx, &entity.AsynchModel{ - SQLBaseDO: beans.SQLBaseDO{Id: req.ID}, - ModelName: req.ModelName, - ModelType: req.ModelType, - BaseURL: req.BaseURL, - HttpMethod: req.HttpMethod, - HeadMsg: req.HeadMsg, - Form: req.Form, - RequestMapping: req.RequestMapping, - ResponseMapping: req.ResponseMapping, - ResponseBody: req.ResponseBody, - ResponseTokenField: req.ResponseTokenField, - IsPrivate: req.IsPrivate, - IsChatModel: req.IsChatModel, - ApiKey: req.ApiKey, - Enabled: req.Enabled, - MaxConcurrency: req.MaxConcurrency, - QueueLimit: req.QueueLimit, - TimeoutSeconds: req.TimeoutSeconds, - ExpectedSeconds: req.ExpectedSeconds, - RetryTimes: req.RetryTimes, - RetryQueueMaxSeconds: req.RetryQueueMaxSeconds, - AutoCleanSeconds: req.AutoCleanSeconds, - Remark: req.Remark, - IsOwner: req.IsOwner, - OperatorName: req.OperatorName, - TokenConfig: req.TokenConfig, - ExtendMapping: req.ExtendMapping, - QueryConfig: req.QueryConfig, - }) - if err != nil { - return err - } - return nil - } - // 判断当前传过来的模型id的模型是否是超级管理员的。如果是超管的进行创建,否则更新 - model, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{ - SQLBaseDO: beans.SQLBaseDO{Id: req.ID}, - }) - if err != nil { - return err - } - if model.TenantId == 1 { - insertDto := new(dto.CreateModelReq) - err = gconv.Struct(req, insertDto) - if err != nil { - return err - } - _, err = dao.Model.Insert(ctx, &entity.AsynchModel{ - ModelName: req.ModelName, - ModelType: req.ModelType, - BaseURL: req.BaseURL, - HttpMethod: req.HttpMethod, - HeadMsg: req.HeadMsg, - Form: req.Form, - RequestMapping: req.RequestMapping, - ResponseMapping: req.ResponseMapping, - ResponseBody: req.ResponseBody, - ResponseTokenField: req.ResponseTokenField, - IsPrivate: req.IsPrivate, - IsChatModel: req.IsChatModel, - ApiKey: req.ApiKey, - Enabled: req.Enabled, - MaxConcurrency: req.MaxConcurrency, - QueueLimit: req.QueueLimit, - TimeoutSeconds: req.TimeoutSeconds, - ExpectedSeconds: req.ExpectedSeconds, - RetryTimes: req.RetryTimes, - RetryQueueMaxSeconds: req.RetryQueueMaxSeconds, - AutoCleanSeconds: req.AutoCleanSeconds, - Remark: req.Remark, - IsOwner: req.IsOwner, - OperatorName: req.OperatorName, - TokenConfig: req.TokenConfig, - ExtendMapping: req.ExtendMapping, - QueryConfig: req.QueryConfig, - }) - return err - } - _, err = dao.Model.Update(ctx, &entity.AsynchModel{ - SQLBaseDO: beans.SQLBaseDO{Id: req.ID}, - ModelName: req.ModelName, - ModelType: req.ModelType, - BaseURL: req.BaseURL, - HttpMethod: req.HttpMethod, - HeadMsg: req.HeadMsg, - Form: req.Form, - RequestMapping: req.RequestMapping, - ResponseMapping: req.ResponseMapping, - ResponseBody: req.ResponseBody, - ResponseTokenField: req.ResponseTokenField, - IsPrivate: req.IsPrivate, - IsChatModel: req.IsChatModel, - ApiKey: req.ApiKey, - Enabled: req.Enabled, - MaxConcurrency: req.MaxConcurrency, - QueueLimit: req.QueueLimit, - TimeoutSeconds: req.TimeoutSeconds, - ExpectedSeconds: req.ExpectedSeconds, - RetryTimes: req.RetryTimes, - RetryQueueMaxSeconds: req.RetryQueueMaxSeconds, - AutoCleanSeconds: req.AutoCleanSeconds, - Remark: req.Remark, - IsOwner: req.IsOwner, - OperatorName: req.OperatorName, - TokenConfig: req.TokenConfig, - ExtendMapping: req.ExtendMapping, - QueryConfig: req.QueryConfig, - }) - return err -} - -func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) error { - _, err := dao.Model.Delete(ctx, &entity.AsynchModel{ - SQLBaseDO: beans.SQLBaseDO{Id: req.ID}, - }) - return err -} - -func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetModelRes, error) { - user, err := utils.GetUserInfo(ctx) - if err != nil { - return nil, err - } - model, err := dao.Model.Get(ctx, &entity.AsynchModel{ - SQLBaseDO: beans.SQLBaseDO{ - Id: req.ID, - Creator: user.UserName, - }, - ModelName: req.ModelName, - }) - if err != nil { - return nil, err - } - return &dto.GetModelRes{ - Model: model, - }, nil -} - -func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) { - var models []*entity.AsynchModel - - req.IsOwner = gconv.PtrInt(1) - admin, err := gateway.IsSuperAdmin(ctx) - if err != nil { - return - } - if admin { - req.IsOwner = gconv.PtrInt(0) - } - - var user *beans.User - user, err = utils.GetUserInfo(ctx) - if err != nil { - return nil, err - } - req.Creator = user.UserName - - models, total, err := dao.Model.GetByCreatorAndPlatform(ctx, req) - if err != nil { - return - } - - return &dto.ListModelRes{ - List: models, - Total: total, - }, nil -} - -// GetModelTypesFromConfig 从配置文件读取模型类型 -func GetModelTypesFromConfig() (res *dto.TypeItem, err error) { - // 返回副本,避免外部修改 - types := make(map[int]string, len(public.ModelTypeName)) - for k, v := range public.ModelTypeName { - types[k] = v - } - return &dto.TypeItem{ - Type: types, - }, nil -} - -// GetOperatorList 获取运营商列表 -func GetOperatorList() (res *dto.ListOperatorRes, err error) { - return &dto.ListOperatorRes{ - List: public.OperatorList, - }, nil -} - -func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error { - // 校验新会话模型是否存在 - newModel, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{ - SQLBaseDO: beans.SQLBaseDO{Id: req.Id}, - }) - if err != nil { - return err - } - if newModel == nil { - return errors.New("新会话模型不存在") - } - var user *beans.User - user, err = utils.GetUserInfo(ctx) - if err != nil { - return err - } - // 获取当前用户会话模型 - currentModel, err := dao.Model.Get(ctx, &entity.AsynchModel{ - SQLBaseDO: beans.SQLBaseDO{ - Creator: user.UserName, - }, - IsChatModel: new(1), - }) - if err != nil { - return err - } - err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error { - if !g.IsEmpty(currentModel) { - if currentModel.ModelType != public.ModelTypeInference { - return errors.New("当前模型为非推理模型,不能设置为会话模型") - } - - // 如果点击的就是当前会话模型(已经是1),取消它(设为0) - if currentModel.Id != req.Id { - _, err = dao.Model.Update(ctx, &entity.AsynchModel{ - SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id}, - IsChatModel: gconv.PtrInt(0), - }) - if err != nil { - return err - } - } - } - - // 设置当前为会话模型(设为1) - _, err = dao.Model.Update(ctx, &entity.AsynchModel{ - SQLBaseDO: beans.SQLBaseDO{Id: req.Id}, - IsChatModel: gconv.PtrInt(1), - }) - return err - }) - return err -} - -func (s *modelService) GetIsChatModel(ctx context.Context) (*dto.GetIsChatModelRes, error) { - user, err := utils.GetUserInfo(ctx) - if err != nil { - return nil, err - } - model, err := dao.Model.Get(ctx, &entity.AsynchModel{ - SQLBaseDO: beans.SQLBaseDO{ - Creator: user.UserName, - }, - IsChatModel: new(1), - }) - if err != nil { - return nil, err - } - if model == nil { - return nil, nil - } - return &dto.GetIsChatModelRes{ - Model: model, - }, nil -} diff --git a/service/auto_tune.go b/service/queue/auto_tune.go similarity index 99% rename from service/auto_tune.go rename to service/queue/auto_tune.go index 481b910..f3d06ea 100644 --- a/service/auto_tune.go +++ b/service/queue/auto_tune.go @@ -1,4 +1,4 @@ -package service +package queue import ( "context" diff --git a/service/queue_gate.go b/service/queue/queue_gate.go similarity index 99% rename from service/queue_gate.go rename to service/queue/queue_gate.go index d9998aa..9e0eb51 100644 --- a/service/queue_gate.go +++ b/service/queue/queue_gate.go @@ -1,4 +1,4 @@ -package service +package queue import ( "context" diff --git a/service/runtime_tune.go b/service/queue/runtime_tune.go similarity index 85% rename from service/runtime_tune.go rename to service/queue/runtime_tune.go index 276fe8c..ec1dd5d 100644 --- a/service/runtime_tune.go +++ b/service/queue/runtime_tune.go @@ -1,4 +1,4 @@ -package service +package queue import ( "context" @@ -11,9 +11,9 @@ import ( // 上层每小时调用 /model/autoTune 写入运行时值;Worker/CreateTask 读取运行时值生效。 const ( - runtimeMaxCKeyPrefix = "asynch:runtime:max_concurrency:" // + model_name - runtimeQueueKeyPrefix = "asynch:runtime:queue_limit:" // + model_name - runtimeTTLSeconds = 2 * 3600 // 2小时,避免一次调参失败导致立即回退 + runtimeMaxCKeyPrefix = "asynch:runtime:max_concurrency:" // + model_name + runtimeQueueKeyPrefix = "asynch:runtime:queue_limit:" // + model_name + runtimeTTLSeconds = 2 * 3600 // 2小时,避免一次调参失败导致立即回退 ) func runtimeMaxConcurrencyKey(modelName string) string { @@ -80,4 +80,3 @@ func clampInt(v, minV, maxV int) int { } return v } - diff --git a/service/semaphore.go b/service/queue/semaphore.go similarity index 81% rename from service/semaphore.go rename to service/queue/semaphore.go index e97a9d4..6f818e6 100644 --- a/service/semaphore.go +++ b/service/queue/semaphore.go @@ -1,4 +1,4 @@ -package service +package queue import ( "context" @@ -34,7 +34,8 @@ end return 1 ` -func acquireSemaphore(ctx context.Context, key string, max int, ttlSeconds int64) (bool, error) { +// AcquireSemaphore 获取并发令牌 +func AcquireSemaphore(ctx context.Context, key string, max int, ttlSeconds int64) (bool, error) { if max <= 0 { // 不限制 return true, nil @@ -49,8 +50,8 @@ func acquireSemaphore(ctx context.Context, key string, max int, ttlSeconds int64 return gconv.Int(r) == 1, nil } -func releaseSemaphore(ctx context.Context, key string) error { +// ReleaseSemaphore 释放并发令牌 +func ReleaseSemaphore(ctx context.Context, key string) error { _, err := g.Redis().Do(ctx, "EVAL", releaseLua, 1, key) return err } - diff --git a/service/stat_service.go b/service/stat/stat_service.go similarity index 98% rename from service/stat_service.go rename to service/stat/stat_service.go index 9aaf7c0..16b4446 100644 --- a/service/stat_service.go +++ b/service/stat/stat_service.go @@ -1,4 +1,4 @@ -package service +package stat import ( "context" diff --git a/service/task_service.go b/service/task/task_service.go similarity index 74% rename from service/task_service.go rename to service/task/task_service.go index 9390588..fe65c9a 100644 --- a/service/task_service.go +++ b/service/task/task_service.go @@ -1,9 +1,10 @@ -package service +package task import ( "context" "errors" "model-gateway/common/util" + "model-gateway/service/queue" "time" "model-gateway/dao" @@ -20,10 +21,11 @@ var Task = &taskService{} type taskService struct{} +// Create 创建任务 func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) { startAt := time.Now() - // 固化 token/user 等信息 - ctx = util.AsyncCtx(ctx) + taskID := uuid.NewString() + // 1) 检查模型配置 m, err := dao.Model.Get(ctx, &entity.AsynchModel{ ModelName: req.ModelName, @@ -35,11 +37,10 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res * return nil, errors.New("模型不存在或未启用") } - taskID := uuid.NewString() // 2) 排队上限(严格控制:Redis 原子闸门) - limit := GetRuntimeQueueLimit(ctx, req.ModelName, m.QueueLimit) + limit := queue.GetRuntimeQueueLimit(ctx, req.ModelName, m.QueueLimit) if limit > 0 { - ok, err := AcquireQueueSlot(ctx, req.ModelName, taskID, limit, m.ExpectedSeconds) + ok, err := queue.AcquireQueueSlot(ctx, req.ModelName, taskID, limit, m.ExpectedSeconds) if err != nil { return nil, err } @@ -48,13 +49,12 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res * } } - // 将调用模型的 payload 与透传头信息一起存入 request_payload,供后台 worker 使用 + // 3) 插入任务记录 storedPayload := map[string]any{ "payload": req.RequestPayload, "headers": util.ForwardHeaders(ctx), } - - t := &entity.AsynchTask{ + _, err = dao.Task.Insert(ctx, &entity.AsynchTask{ ModelName: req.ModelName, TaskID: taskID, State: 0, @@ -64,21 +64,20 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res * InputRef: req.InputRef, RequestPayload: storedPayload, EpicycleId: req.EpicycleId, - } - _, err = dao.Task.Insert(ctx, t) + }) if err != nil { // 入库失败:回滚闸门占位 - ReleaseQueueSlot(ctx, req.ModelName, taskID) + queue.ReleaseQueueSlot(ctx, req.ModelName, taskID) return nil, err } - // 3) 写操作日志(尽量不影响主流程,失败忽略) + // 4) 写操作日志(不影响主流程,失败忽略) ip := "" ua := "" apiPath := "/task/createTask" httpMethod := "POST" if r := g.RequestFromCtx(ctx); r != nil { - ip = r.GetClientIp() + ip = util.GetLocalIP() ua = r.UserAgent() apiPath = r.URL.Path httpMethod = r.Method @@ -101,70 +100,68 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res * }, }) - // 4) 创建成功后立即异步尝试执行当前任务,并仅在任务仍处于 pending(state=0) 时做定向轮询。 + // 5) 创建成功后立即异步尝试执行当前任务,并仅在任务仍处于 pending(state=0) 时做定向轮询。 // 一旦任务进入 running/success/failed/downloaded,就停止轮询,避免一直空转。 - go s.pollAndRunUntilPicked(context.WithoutCancel(ctx), taskID, req) + go s.pollAndRunUntilPicked(util.AsyncCtx(ctx), taskID, req) return &dto.CreateTaskRes{TaskID: taskID}, nil } -// pollAndRunUntilPicked 用于 createTask 创建后的“轻量级定向轮询”: +// pollAndRunUntilPicked 定向轮询执行刚创建的任务 // - 目标:尽快把刚创建的任务拉起来执行 // - 只在任务仍为 pending(state=0) 时继续尝试抢占 // - 一旦任务进入 running(1) / success(2) / failed(3) / downloaded(4),立即停止 -// - 这样不会无限轮询;runWork 仍负责处理积压队列和未处理到的任务 +// - 不会无限轮询;runWork 仍负责处理积压队列和未处理到的任务 func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string, req *dto.CreateTaskReq) { - if taskID == "" { - return - } - interval := g.Cfg().MustGet(ctx, "asynch.worker.intervalSeconds").Int() - if interval <= 0 { - interval = 5 - } - g.Log().Infof(ctx, "[task-auto-run][start] taskId=%s interval=%ds", taskID, interval) + interval := g.Cfg().MustGet(ctx, "asynch.worker.intervalSeconds", 5).Int() + pollTimeout := g.Cfg().MustGet(ctx, "asynch.worker.pollTimeoutSeconds", 300).Int() + pollCtx, cancel := context.WithTimeout(ctx, time.Duration(pollTimeout)*time.Second) + defer cancel() ticker := time.NewTicker(time.Duration(interval) * time.Second) defer ticker.Stop() + g.Log().Infof(ctx, "[任务自动执行][开始] taskId=%s 轮询间隔=%ds 超时=%ds", taskID, interval, pollTimeout) tryRun := func() bool { t, err := dao.Task.Get(ctx, &entity.AsynchTask{ TaskID: taskID, }) if err != nil { - g.Log().Warningf(ctx, "[task-auto-run][stop] taskId=%s reason=query_failed err=%v", taskID, err) + g.Log().Warningf(ctx, "[任务自动执行][停止] taskId=%s 原因=查询失败 err=%v", taskID, err) return true } if t == nil { - g.Log().Warningf(ctx, "[task-auto-run][stop] taskId=%s reason=task_not_found", taskID) + g.Log().Warningf(ctx, "[任务自动执行][停止] taskId=%s 原因=任务不存在", taskID) return true } + switch t.State { case 0: + //RunByTaskID 尝试执行任务 if err = AsyncWorker.RunByTaskID(ctx, taskID, req); err != nil { - g.Log().Warningf(ctx, "[task-auto-run][retry] taskId=%s state=0 err=%v", taskID, err) + g.Log().Warningf(ctx, "[任务自动执行][重试] taskId=%s 状态=待处理 err=%v", taskID, err) } else { - g.Log().Infof(ctx, "[task-auto-run][triggered] taskId=%s state=0", taskID) + g.Log().Infof(ctx, "[任务自动执行][已触发] taskId=%s 状态=待处理", taskID) } return false case 1: - g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=running", taskID) + g.Log().Infof(ctx, "[任务自动执行][停止] taskId=%s 原因=执行中", taskID) return true case 2, 3, 4: - g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=terminal state=%d", taskID, t.State) + g.Log().Infof(ctx, "[任务自动执行][停止] taskId=%s 原因=终态 状态=%d", taskID, t.State) return true default: - g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=unknown_state state=%d", taskID, t.State) + g.Log().Infof(ctx, "[任务自动执行][停止] taskId=%s 原因=未知状态 状态=%d", taskID, t.State) return true } } - - // 先立即尝试一次 + // 立即尝试一次 if stop := tryRun(); stop { return } for { select { - case <-ctx.Done(): - g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=context_done", taskID) + case <-pollCtx.Done(): + g.Log().Infof(ctx, "[任务自动执行][停止] taskId=%s 原因=轮询超时", taskID) return case <-ticker.C: if stop := tryRun(); stop { @@ -174,6 +171,7 @@ func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string, } } +// GetResult 获取任务结果 func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) { t, err := dao.Task.Get(ctx, &entity.AsynchTask{ TaskID: taskID, @@ -244,6 +242,7 @@ func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (r return &dto.GetTaskBatchRes{List: items}, nil } +// List 获取任务列表 func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) { pageNum, pageSize := 1, 10 if req != nil { diff --git a/service/worker.go b/service/task/worker.go similarity index 53% rename from service/worker.go rename to service/task/worker.go index 331d114..1a2b03e 100644 --- a/service/worker.go +++ b/service/task/worker.go @@ -1,12 +1,17 @@ -package service +package task import ( + "bytes" "context" + "encoding/json" "errors" "fmt" + "io" "model-gateway/common/util" "model-gateway/model/dto" "model-gateway/service/gateway" + "model-gateway/service/queue" + "net/http" "os" "path/filepath" "strings" @@ -56,7 +61,7 @@ func (w *asyncWorker) RunOnce(ctx context.Context, req *dto.RunWorkReq) (res *dt if e != nil { task.ErrorMsg = fmt.Sprintf("worker panic: %v", e) _ = dao.Task.UpdateFailedGlobal(ctx, task) - ReleaseQueueSlot(ctx, task.ModelName, task.TaskID) + queue.ReleaseQueueSlot(ctx, task.ModelName, task.TaskID) } done <- struct{}{} }) @@ -100,8 +105,8 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req * // 2) 分布式并发控制 semKey := fmt.Sprintf("asynch:sem:%s", t.ModelName) - maxC := GetRuntimeMaxConcurrency(ctx, t.ModelName, model.MaxConcurrency) - acquired, err := acquireSemaphore(ctx, semKey, maxC, 3600) + maxC := queue.GetRuntimeMaxConcurrency(ctx, t.ModelName, model.MaxConcurrency) + acquired, err := queue.AcquireSemaphore(ctx, semKey, maxC, 3600) if err != nil { w.failTask(ctx, t, err.Error()) return @@ -111,7 +116,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req * _ = w.rollbackToPending(ctx, t.Id) return } - defer func() { _ = releaseSemaphore(ctx, semKey) }() + defer func() { _ = queue.ReleaseSemaphore(ctx, semKey) }() // 3) request_payload 校验 if payload == nil { @@ -146,31 +151,32 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req * } // 6) 解析校验(可重试,失败重新调模型) - if req.BuildType == 1 { - for attempt := 0; attempt <= maxRetry; attempt++ { - if attempt > 0 { - g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID) - } - err = util.ValidatePromptResult(textResult, model.RequestMapping) - if err == nil { - break - } - g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v", - t.TaskID, attempt, maxRetry, err) - if attempt == maxRetry { - w.failTask(ctx, t, fmt.Sprintf("JSON解析重试耗尽: %v", err)) - return - } - // 重新调模型 - newResult, modelErr := w.callModel(ctx, t, model, payload) - if modelErr != nil { - g.Log().Warningf(ctx, "[执行任务][重试] 重新调模型失败 taskId=%s attempt=%d/%d err=%v", - t.TaskID, attempt, maxRetry, modelErr) - continue - } - textResult = newResult - } - } + //if req.BuildType == 1 { + // for attempt := 0; attempt <= maxRetry; attempt++ { + // if attempt > 0 { + // g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID) + // } + // // 6.1) 校验数据 + // err = util.ValidatePromptResult(textResult, model) + // if err == nil { + // break + // } + // g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v", + // t.TaskID, attempt, maxRetry, err) + // if attempt == maxRetry { + // w.failTask(ctx, t, fmt.Sprintf("JSON解析重试耗尽: %v", err)) + // return + // } + // // 6.2) 重新调模型 + // newResult, modelErr := w.callModel(ctx, t, model, payload) + // if modelErr != nil { + // g.Log().Warningf(ctx, "[执行任务][重试] 重新调模型失败 taskId=%s attempt=%d/%d err=%v", + // t.TaskID, attempt, maxRetry, modelErr) + // continue + // } + // textResult = newResult + // } + //} // 7) 成功回调 t.State = 2 @@ -185,7 +191,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req * return } - ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) + queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) go gateway.TriggerCallback(context.WithoutCancel(ctx), t) if req.EpicycleId != 0 { go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), t, req.EpicycleId) @@ -198,29 +204,29 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req * // 返回: ossURL(成功时有值), fileType, textResult(失败时是错误信息), retryable(是否可重试) // callModel 调用模型 + 检测文件类型 + 保存临时文件 -func (w *asyncWorker) callModel(ctx context.Context, t *entity.AsynchTask, m *entity.AsynchModel, payload map[string]any) (map[string]any, error) { +func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, payload map[string]any) (map[string]any, error) { var data []byte var contentType, ext, textResult string var err error - if t.Phase == 1 && strings.TrimSpace(t.TmpFile) != "" { - data, err = os.ReadFile(t.TmpFile) + if task.Phase == 1 && strings.TrimSpace(task.TmpFile) != "" { + data, err = os.ReadFile(task.TmpFile) if err != nil || len(data) == 0 { data = nil } } if data == nil { - _ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(t.TenantId), t.Creator, t.ModelName) - data, err = InvokeModel(ctx, m, payload, t.ModelKey) + _ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName) + data, err = InvokeModel(ctx, model, payload, task.ModelKey) if err != nil { return nil, err } - tmpPath, tmpErr := saveTmpResult(t.TaskID, data, ext) + tmpPath, tmpErr := saveTmpResult(task.TaskID, data, ext) if tmpErr == nil && tmpPath != "" { - t.TmpFile = tmpPath - t.Phase = 1 - _ = dao.Task.UpdateTmpAfterModelGlobal(ctx, t.Id, tmpPath) + task.TmpFile = tmpPath + task.Phase = 1 + _ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath) } } @@ -228,10 +234,138 @@ func (w *asyncWorker) callModel(ctx context.Context, t *entity.AsynchTask, m *en if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") { textResult = string(data) } - return gjson.New(textResult).Map(), nil } +// InvokeModel 调用模型服务,返回二进制结果 +// modelKey 用于覆盖/补充模型配置 head_msg(例如每次请求携带不同的 X-API-Key) +func InvokeModel(ctx context.Context, model *entity.AsynchModel, payload map[string]any, modelKey string) ([]byte, error) { + // 1)请求参数映射:将标准 payload 按模型配置的 requestMapping 转为模型需要的格式 + //mappedPayload := util.ReverseMap(model.RequestMapping, payload) + + // 2)构建请求 URL 和超时 + baseURL := strings.TrimRight(model.BaseURL, "/") + timeout := time.Duration(model.TimeoutSeconds) * time.Second + client := &http.Client{Timeout: timeout} + method := strings.ToUpper(strings.TrimSpace(model.HttpMethod)) + + // 3)构建 HTTP 请求 + var req *http.Request + switch method { + case http.MethodGet: + q, err := util.PayloadToQuery(payload) + if err != nil { + return nil, err + } + if len(q) > 0 { + if strings.Contains(baseURL, "?") { + baseURL = baseURL + "&" + q.Encode() + } else { + baseURL = baseURL + "?" + q.Encode() + } + } + req, err = http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil) + default: + bodyBytes, err := json.Marshal(payload) + if err != nil { + return nil, err + } + req, err = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes)) + } + + // 4)注入请求头:先模型静态配置,再动态 modelKey(后者可覆盖前者) + for hk, hv := range util.ParseHeadMsgHeaders(model.HeadMsg) { + req.Header.Set(hk, hv) + } + for hk, hv := range util.ParseHeadMsgHeaders(modelKey) { + req.Header.Set(hk, hv) + } + if method != http.MethodGet { + req.Header.Set("Content-Type", "application/json") + } + + // 5)发送请求 + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + // 6)读取响应体 + b, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + // 7)检查 HTTP 状态码 + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + msg := string(b) + return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg) + } + + // 8)响应参数映射 + mappedResponse, err := util.MapResponsePayload(model.ResponseMapping, b) + if err != nil { + g.Log().Warningf(ctx, "响应参数映射失败: %v,返回原始数据", err) + return b, nil + } + return mappedResponse, nil +} + +// // InvokeModel 调用模型服务,返回二进制结果 +// +// func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelKey string) ([]byte, error) { +// if m == nil || m.BaseURL == "" { +// return nil, fmt.Errorf("模型配置不完整") +// } +// // 请求参数映射 +// mappedPayload, err := mapRequestPayload(m.RequestMapping, payload) +// if err != nil { +// return nil, fmt.Errorf("请求参数映射失败: %w", err) +// } +// // 合并请求头 +// headers := util.ForwardHeaders(ctx) +// for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) { +// headers[hk] = hv +// } +// for hk, hv := range parseHeadMsgHeaders(modelKey) { +// headers[hk] = hv +// } +// +// // 设置超时 +// timeout := time.Duration(m.TimeoutSeconds) * time.Second +// if timeout <= 0 { +// timeout = 600 * time.Second +// } +// ctx, cancel := context.WithTimeout(ctx, timeout) +// defer cancel() +// +// invokeUrl := strings.TrimRight(m.BaseURL, "/") +// method := strings.ToUpper(strings.TrimSpace(m.HttpMethod)) +// if method == "" { +// method = http.MethodPost +// } +// +// var respBytes []byte +// +// switch method { +// case http.MethodGet: +// err = commonHttp.Get(ctx, invokeUrl, headers, &respBytes, mappedPayload) +// default: +// err = commonHttp.Post(ctx, invokeUrl, headers, &respBytes, mappedPayload) +// } +// if err != nil { +// return nil, err +// } +// // 响应参数映射 +// mappedResponse, err := mapResponsePayload(m.ResponseMapping, respBytes) +// if err != nil { +// g.Log().Warningf(ctx, "响应参数映射失败: %v,返回原始数据", err) +// return respBytes, nil +// } +// return mappedResponse, nil +// } + // uploadOSS 从临时文件上传 OSS func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.AsynchTask) (*gateway.UploadFileResponse, error) { data, err := os.ReadFile(t.TmpFile) @@ -247,7 +381,7 @@ func (w *asyncWorker) failTask(ctx context.Context, t *entity.AsynchTask, errMsg t.State = 3 t.ErrorMsg = errMsg _ = dao.Task.UpdateFailedGlobal(ctx, t) - ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) + queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID) go gateway.TriggerCallback(context.WithoutCancel(ctx), t) }