refactor(service): 重构服务模块结构并优化模型配置
This commit is contained in:
10
common/util/convert.go
Normal file
10
common/util/convert.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package util
|
||||
|
||||
import "github.com/gogf/gf/v2/util/gconv"
|
||||
|
||||
// ConvertTo 转换为指定类型
|
||||
func ConvertTo[T any](v interface{}) *T {
|
||||
var t T
|
||||
_ = gconv.Struct(v, &t)
|
||||
return &t
|
||||
}
|
||||
@@ -1,69 +0,0 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ValidatePromptResult 完整的校验逻辑
|
||||
func ValidatePromptResult(raw map[string]any, requestMapping map[string]any) error {
|
||||
contentStr, ok := raw["content"].(string)
|
||||
if !ok || contentStr == "" {
|
||||
return fmt.Errorf("content 字段为空或不是字符串")
|
||||
}
|
||||
|
||||
var rounds []map[string]any
|
||||
if err := json.Unmarshal([]byte(contentStr), &rounds); err != nil {
|
||||
return fmt.Errorf("解析 content JSON 数组失败: %w", err)
|
||||
}
|
||||
if len(rounds) == 0 {
|
||||
return fmt.Errorf("content 数组为空")
|
||||
}
|
||||
|
||||
// 对 rounds 中的每一个元素进行结构校验
|
||||
for i, round := range rounds {
|
||||
if err := validateStructure(requestMapping, round); err != nil {
|
||||
return fmt.Errorf("rounds[%d] 结构校验失败: %w", i, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateStructure 递归校验 actual 是否包含 expected 定义的所有字段路径
|
||||
func validateStructure(expected any, actual any) error {
|
||||
switch exp := expected.(type) {
|
||||
case map[string]any:
|
||||
act, ok := actual.(map[string]any)
|
||||
if !ok {
|
||||
return fmt.Errorf("期望对象,实际类型 %T", actual)
|
||||
}
|
||||
for key, expVal := range exp {
|
||||
actVal, exists := act[key]
|
||||
if !exists {
|
||||
return fmt.Errorf("缺少字段: %s", key)
|
||||
}
|
||||
if err := validateStructure(expVal, actVal); err != nil {
|
||||
return fmt.Errorf("%s: %w", key, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case []any:
|
||||
act, ok := actual.([]any)
|
||||
if !ok {
|
||||
return fmt.Errorf("期望数组,实际类型 %T", actual)
|
||||
}
|
||||
if len(exp) == 0 {
|
||||
return nil // 空数组模板,只校验类型
|
||||
}
|
||||
// 用第一个元素的结构去校验每个实际元素
|
||||
for i, actItem := range act {
|
||||
if err := validateStructure(exp[0], actItem); err != nil {
|
||||
return fmt.Errorf("[%d]: %w", i, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
// 基本类型,不校验具体值,只检查存在
|
||||
return nil
|
||||
}
|
||||
}
|
||||
151
common/util/mapping.go
Normal file
151
common/util/mapping.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"model-gateway/model/entity"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// ValidatePromptResult 校验模型返回结果的 JSON 结构完整性
|
||||
// 校验逻辑:只校验 requestMapping 中默认值为空的必填字段
|
||||
func ValidatePromptResult(raw map[string]any, model *entity.AsynchModel) error {
|
||||
// 1) 获取校验配置,并取值
|
||||
requestMapping := model.RequestMapping
|
||||
contentKey := ""
|
||||
for k := range model.ResponseBody {
|
||||
contentKey = k
|
||||
break
|
||||
}
|
||||
contentStr, ok := raw[contentKey].(string)
|
||||
if !ok || contentStr == "" {
|
||||
return fmt.Errorf("%s 字段为空或不是字符串", contentKey)
|
||||
}
|
||||
|
||||
// 2) 解析 content 为 JSON 数组
|
||||
var rounds []map[string]any
|
||||
if err := gjson.DecodeTo(contentStr, &rounds); err != nil {
|
||||
return fmt.Errorf("解析 content JSON 数组失败: %w", err)
|
||||
}
|
||||
if len(rounds) == 0 {
|
||||
return fmt.Errorf("content 数组为空")
|
||||
}
|
||||
|
||||
// 3) 逐条校验:只检查默认值为空的必填字段是否存在
|
||||
for i, round := range rounds {
|
||||
for path, defaultValue := range requestMapping {
|
||||
if !g.IsEmpty(defaultValue) {
|
||||
continue
|
||||
}
|
||||
if gjson.New(round).Get(path).IsNil() {
|
||||
return fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, path)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReverseMap 映射 payload 到 mapping
|
||||
func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
|
||||
jsonObj := gjson.New("{}")
|
||||
for path, defaultValue := range mapping {
|
||||
// 从 payload 取对应路径的值
|
||||
val := gjson.New(payload).Get(path)
|
||||
if !val.IsNil() {
|
||||
// payload 有值,用它
|
||||
_ = jsonObj.Set(path, val.Val())
|
||||
} else if !g.IsEmpty(defaultValue) {
|
||||
// payload 没值,用默认值
|
||||
_ = jsonObj.Set(path, defaultValue)
|
||||
}
|
||||
}
|
||||
return jsonObj.Map()
|
||||
}
|
||||
|
||||
// MapResponsePayload 映射模型响应为标准格式
|
||||
func MapResponsePayload(mapping map[string]any, responseBytes []byte) ([]byte, error) {
|
||||
if len(mapping) == 0 {
|
||||
return responseBytes, nil
|
||||
}
|
||||
|
||||
responseJson := gjson.New(responseBytes)
|
||||
resultJson := gjson.New("{}")
|
||||
|
||||
for standardField, modelPath := range mapping {
|
||||
path := gconv.String(modelPath)
|
||||
if path == "" {
|
||||
continue
|
||||
}
|
||||
val := responseJson.Get(path)
|
||||
if val.IsNil() {
|
||||
continue
|
||||
}
|
||||
resultJson.Set(standardField, val.Val())
|
||||
}
|
||||
|
||||
return []byte(resultJson.String()), nil
|
||||
}
|
||||
|
||||
// ParseHeadMsgHeaders 支持多个 header 绑定,逗号分隔:
|
||||
// 示例:
|
||||
// - X-API-Key:qwen3-tts-key,operation:true,count:123
|
||||
// - X-API-Key:"qwen3-tts-key",operation:"true"
|
||||
//
|
||||
// 说明:
|
||||
// - HTTP Header 最终都是字符串,这里做的是“值的字符串化表达”。
|
||||
// - 若 value 用双引号包裹,会去掉外层引号再注入,便于在配置中区分字符串/布尔/数字等表达(以及避免值中包含特殊字符时歧义)。
|
||||
func ParseHeadMsgHeaders(headMsg string) map[string]string {
|
||||
headMsg = strings.TrimSpace(headMsg)
|
||||
if headMsg == "" {
|
||||
return nil
|
||||
}
|
||||
out := map[string]string{}
|
||||
parts := strings.Split(headMsg, ",")
|
||||
for _, p := range parts {
|
||||
p = strings.TrimSpace(p)
|
||||
if p == "" {
|
||||
continue
|
||||
}
|
||||
// HeaderName:HeaderValue(推荐) / HeaderName=HeaderValue(兼容)
|
||||
if strings.Contains(p, ":") {
|
||||
kv := strings.SplitN(p, ":", 2)
|
||||
k := strings.TrimSpace(kv[0])
|
||||
v := strings.TrimSpace(kv[1])
|
||||
v = strings.Trim(v, "\"")
|
||||
if k != "" && v != "" {
|
||||
out[k] = v
|
||||
}
|
||||
continue
|
||||
}
|
||||
if strings.Contains(p, "=") {
|
||||
kv := strings.SplitN(p, "=", 2)
|
||||
k := strings.TrimSpace(kv[0])
|
||||
v := strings.TrimSpace(kv[1])
|
||||
v = strings.Trim(v, "\"")
|
||||
if k != "" && v != "" {
|
||||
out[k] = v
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// PayloadToQuery 将 payload 转为 url.Values
|
||||
func PayloadToQuery(payload map[string]any) (url.Values, error) {
|
||||
q := url.Values{}
|
||||
for k, v := range payload {
|
||||
if v == nil {
|
||||
continue
|
||||
}
|
||||
q.Set(k, gconv.String(v))
|
||||
}
|
||||
return q, nil
|
||||
}
|
||||
140
common/util/network.go
Normal file
140
common/util/network.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
// GetLocalIP 获取本机有效的局域网 IPv4 地址
|
||||
func GetLocalIP() string {
|
||||
addrs, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
return "127.0.0.1"
|
||||
}
|
||||
|
||||
var validIPs []string
|
||||
|
||||
for _, addr := range addrs {
|
||||
ipnet, ok := addr.(*net.IPNet)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
ip := ipnet.IP
|
||||
|
||||
if isIPValid(ip) {
|
||||
validIPs = append(validIPs, ip.String())
|
||||
}
|
||||
}
|
||||
|
||||
// 优先返回非 169.254.x.x 的 IP
|
||||
for _, ip := range validIPs {
|
||||
if !strings.HasPrefix(ip, "169.254.") {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
|
||||
// 其次返回 169.254.x.x(最后的选择)
|
||||
if len(validIPs) > 0 {
|
||||
return validIPs[0]
|
||||
}
|
||||
|
||||
return "127.0.0.1"
|
||||
}
|
||||
|
||||
// isIPValid 判断 IP 是否有效
|
||||
func isIPValid(ip net.IP) bool {
|
||||
// 不是 loopback (127.0.0.1)
|
||||
if ip.IsLoopback() {
|
||||
return false
|
||||
}
|
||||
|
||||
// 是 IPv4
|
||||
if ip.To4() == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 不是链路本地地址 (169.254.0.0/16)
|
||||
if ip[0] == 169 && ip[1] == 254 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 不是组播地址
|
||||
if ip.IsMulticast() {
|
||||
return false
|
||||
}
|
||||
|
||||
// 不是未指定地址 (0.0.0.0)
|
||||
if ip.IsUnspecified() {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// GetLocalAddress 获取局域网地址(IP:端口)
|
||||
func GetLocalAddress(ctx context.Context) string {
|
||||
ip := GetLocalIP()
|
||||
port := GetServerPort(ctx)
|
||||
|
||||
if port == "80" || port == "443" {
|
||||
return ip
|
||||
}
|
||||
return ip + ":" + port
|
||||
}
|
||||
|
||||
// GetSchemaFromRequest 从当前请求中获取协议(http/https)
|
||||
func GetSchemaFromRequest(ctx context.Context) string {
|
||||
r := g.RequestFromCtx(ctx)
|
||||
if r == nil {
|
||||
return "http"
|
||||
}
|
||||
|
||||
// 1. 代理场景:X-Forwarded-Proto
|
||||
if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" {
|
||||
return proto
|
||||
}
|
||||
|
||||
// 2. 代理场景:X-Forwarded-Scheme
|
||||
if proto := r.Header.Get("X-Forwarded-Scheme"); proto != "" {
|
||||
return proto
|
||||
}
|
||||
|
||||
// 3. TLS 连接(直接 HTTPS)
|
||||
if r.TLS != nil {
|
||||
return "https"
|
||||
}
|
||||
|
||||
// 4. 默认 HTTP(这行很重要!)
|
||||
return "http" // ← 确保有这行
|
||||
}
|
||||
|
||||
// GetLocalBaseURL 获取局域网基础 URL(动态协议 + IP + 端口)
|
||||
func GetLocalBaseURL(ctx context.Context) string {
|
||||
schema := GetSchemaFromRequest(ctx)
|
||||
addr := GetLocalAddress(ctx)
|
||||
return schema + "://" + addr
|
||||
}
|
||||
|
||||
// GetCallbackURL 获取回调地址(完整 URL)
|
||||
func GetCallbackURL(ctx context.Context, path string) string {
|
||||
baseURL := GetLocalBaseURL(ctx)
|
||||
// 确保 path 以 / 开头
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
return baseURL + path
|
||||
}
|
||||
|
||||
// GetServerPort 从配置获取服务端口
|
||||
func GetServerPort(ctx context.Context) string {
|
||||
address := g.Cfg().MustGet(ctx, "server.address", ":8080").String()
|
||||
// address 格式如 ":3009",去掉冒号
|
||||
if strings.HasPrefix(address, ":") {
|
||||
return address[1:]
|
||||
}
|
||||
return "8080"
|
||||
}
|
||||
@@ -4,11 +4,12 @@ package public
|
||||
const (
|
||||
ModelTypeInference = 100 // 推理模型
|
||||
|
||||
ModelTypeImage = 200 // 图片模型
|
||||
ImageSubTypeTextToImage = 201 // 图片模型-文生图
|
||||
ImageSubTypeImageToImage = 202 // 图片模型-图生图
|
||||
ImageSubTypeImageEdit = 203 // 图片模型-图片编辑
|
||||
ImageSubTypeImageVariation = 204 // 图片模型-图片变体
|
||||
ModelTypeImage = 200 // 图片模型
|
||||
ImageSubTypeTextToImage = 201 // 图片模型-文生图
|
||||
ImageSubTypeImageToImage = 202 // 图片模型-图生图
|
||||
ImageSubTypeImageEdit = 203 // 图片模型-图片编辑
|
||||
ImageSubTypeImageVariation = 204 // 图片模型-图片变体
|
||||
ImageSubTypeImageTextToImage = 205 // 图片模型-图文生图
|
||||
|
||||
ModelTypeAudio = 300 // 音频模型
|
||||
AudioSubTypeTextToSpeech = 301 // 音频模型-文生音
|
||||
@@ -35,11 +36,12 @@ const (
|
||||
var ModelTypeName = map[int]string{
|
||||
ModelTypeInference: "推理模型",
|
||||
|
||||
ModelTypeImage: "图片模型",
|
||||
ImageSubTypeTextToImage: "图片模型-文生图",
|
||||
ImageSubTypeImageToImage: "图片模型-图生图",
|
||||
ImageSubTypeImageEdit: "图片模型-图片编辑",
|
||||
ImageSubTypeImageVariation: "图片模型-图片变体",
|
||||
ModelTypeImage: "图片模型",
|
||||
ImageSubTypeTextToImage: "图片模型-文生图",
|
||||
ImageSubTypeImageToImage: "图片模型-图生图",
|
||||
ImageSubTypeImageEdit: "图片模型-图片编辑",
|
||||
ImageSubTypeImageVariation: "图片模型-图片变体",
|
||||
ImageSubTypeImageTextToImage: "图片模型-图文生图",
|
||||
|
||||
ModelTypeAudio: "音频模型",
|
||||
AudioSubTypeTextToSpeech: "音频模型-文生音",
|
||||
|
||||
@@ -2,9 +2,9 @@ package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"model-gateway/model/dto"
|
||||
"model-gateway/service"
|
||||
modelService "model-gateway/service/model"
|
||||
"model-gateway/service/queue"
|
||||
)
|
||||
|
||||
type model struct{}
|
||||
@@ -14,53 +14,53 @@ var Model = new(model)
|
||||
|
||||
// CreateModel 添加配置
|
||||
func (c *model) CreateModel(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
|
||||
return service.Model.Create(ctx, req)
|
||||
return modelService.Model.Create(ctx, req)
|
||||
}
|
||||
|
||||
// UpdateModel 更改配置
|
||||
func (c *model) UpdateModel(ctx context.Context, req *dto.UpdateModelReq) (res *dto.UpdateModelRes, err error) {
|
||||
err = service.Model.Update(ctx, req)
|
||||
err = modelService.Model.Update(ctx, req)
|
||||
return
|
||||
}
|
||||
|
||||
// DeleteModel 删除配置
|
||||
func (c *model) DeleteModel(ctx context.Context, req *dto.DeleteModelReq) (res *dto.DeleteModelRes, err error) {
|
||||
err = service.Model.Delete(ctx, req)
|
||||
err = modelService.Model.Delete(ctx, req)
|
||||
return
|
||||
}
|
||||
|
||||
// GetModel 获取配置详情
|
||||
func (c *model) GetModel(ctx context.Context, req *dto.GetModelReq) (res *dto.GetModelRes, err error) {
|
||||
return service.Model.Get(ctx, req)
|
||||
return modelService.Model.Get(ctx, req)
|
||||
}
|
||||
|
||||
// ListModel 配置列表
|
||||
func (c *model) ListModel(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) {
|
||||
return service.Model.List(ctx, req)
|
||||
return modelService.Model.List(ctx, req)
|
||||
}
|
||||
|
||||
// AutoTune 动态调参(由上层定时任务每小时触发一次)
|
||||
func (c *model) AutoTune(ctx context.Context, req *dto.AutoTuneReq) (res *dto.AutoTuneRes, err error) {
|
||||
return service.AutoTune(ctx, req)
|
||||
return queue.AutoTune(ctx, req)
|
||||
}
|
||||
|
||||
// ListType 模型类型列表
|
||||
func (c *model) ListType(ctx context.Context, req *dto.ListTypeReq) (res *dto.TypeItem, err error) {
|
||||
return service.GetModelTypesFromConfig()
|
||||
return modelService.GetModelTypesFromConfig()
|
||||
}
|
||||
|
||||
// ListOperator 运营商列表
|
||||
func (c *model) ListOperator(ctx context.Context, req *dto.ListOperatorReq) (res *dto.ListOperatorRes, err error) {
|
||||
return service.GetOperatorList()
|
||||
return modelService.GetOperatorList()
|
||||
}
|
||||
|
||||
// UpdateChatModel 更新是否为聊天模型
|
||||
func (c *model) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) (res *dto.UpdateChatModelRes, err error) {
|
||||
err = service.Model.UpdateChatModel(ctx, req)
|
||||
err = modelService.Model.UpdateChatModel(ctx, req)
|
||||
return
|
||||
}
|
||||
|
||||
// GetIsChatModel 获取当前会话模型
|
||||
func (c *model) GetIsChatModel(ctx context.Context, req *dto.GetIsChatModelReq) (res *dto.GetIsChatModelRes, err error) {
|
||||
return service.Model.GetIsChatModel(ctx)
|
||||
return modelService.Model.GetIsChatModel(ctx)
|
||||
}
|
||||
|
||||
@@ -2,9 +2,9 @@ package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
statService "model-gateway/service/stat"
|
||||
|
||||
"model-gateway/model/dto"
|
||||
"model-gateway/service"
|
||||
)
|
||||
|
||||
type stat struct{}
|
||||
@@ -14,5 +14,5 @@ var Stat = new(stat)
|
||||
|
||||
// ListModelStat 统计列表
|
||||
func (c *stat) ListModelStat(ctx context.Context, req *dto.ListModelStatReq) (res *dto.ListModelStatRes, err error) {
|
||||
return service.Stat.List(ctx, req)
|
||||
return statService.Stat.List(ctx, req)
|
||||
}
|
||||
|
||||
@@ -2,9 +2,10 @@ package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"model-gateway/service/job"
|
||||
taskService "model-gateway/service/task"
|
||||
|
||||
"model-gateway/model/dto"
|
||||
"model-gateway/service"
|
||||
)
|
||||
|
||||
type task struct{}
|
||||
@@ -14,30 +15,30 @@ var Task = new(task)
|
||||
|
||||
// CreateTask 根据 modelName 创建异步任务,返回 taskId
|
||||
func (c *task) CreateTask(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
|
||||
return service.Task.Create(ctx, req)
|
||||
return taskService.Task.Create(ctx, req)
|
||||
}
|
||||
|
||||
// GetTaskResult 获取任务结果(只返回 oss 地址 + state)
|
||||
func (c *task) GetTaskResult(ctx context.Context, req *dto.GetTaskResultReq) (res *dto.GetTaskResultRes, err error) {
|
||||
return service.Task.GetResult(ctx, req.TaskID)
|
||||
return taskService.Task.GetResult(ctx, req.TaskID)
|
||||
}
|
||||
|
||||
// GetTaskBatch 批量查询任务(成功任务标记为已下载)
|
||||
func (c *task) GetTaskBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
|
||||
return service.Task.GetBatch(ctx, req)
|
||||
return taskService.Task.GetBatch(ctx, req)
|
||||
}
|
||||
|
||||
// ListTask 任务列表分页查询
|
||||
func (c *task) ListTask(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
|
||||
return service.Task.List(ctx, req)
|
||||
return taskService.Task.List(ctx, req)
|
||||
}
|
||||
|
||||
// RunWork 手动触发一次 worker(由上层定时任务调用)
|
||||
func (c *task) RunWork(ctx context.Context, req *dto.RunWorkReq) (res *dto.RunWorkRes, err error) {
|
||||
return service.AsyncWorker.RunOnce(ctx, req)
|
||||
return taskService.AsyncWorker.RunOnce(ctx, req)
|
||||
}
|
||||
|
||||
// CleanWork 手动触发一次 cleaner(由上层定时任务调用)
|
||||
func (c *task) CleanWork(ctx context.Context, req *dto.CleanWorkReq) (res *dto.CleanWorkRes, err error) {
|
||||
return service.Cleaner.RunOnce(ctx)
|
||||
return job.Cleaner.RunOnce(ctx)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"model-gateway/consts/public"
|
||||
"model-gateway/model/dto"
|
||||
"model-gateway/model/entity"
|
||||
"strconv"
|
||||
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
@@ -90,22 +91,28 @@ func (d *modelDao) GetByAcrossTenant(ctx context.Context, req *entity.AsynchMode
|
||||
|
||||
// GetByCreatorAndPlatform 按创建者、平台获取
|
||||
func (d *modelDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.AsynchModel, total int, err error) {
|
||||
// 基础 SQL
|
||||
sql := `
|
||||
SELECT DISTINCT ON (model_name) *
|
||||
FROM asynch_models
|
||||
WHERE deleted_at IS NULL
|
||||
AND (? = '' OR model_name LIKE ?)
|
||||
AND (? = 0 OR model_type = ?)
|
||||
`
|
||||
args := []any{
|
||||
req.ModelName, "%" + req.ModelName + "%",
|
||||
req.ModelType, req.ModelType,
|
||||
}
|
||||
|
||||
// modelType: 传 6 模糊匹配 6%
|
||||
if req.ModelType > 0 {
|
||||
prefix := strconv.Itoa(req.ModelType)[:1] // 截取第一位
|
||||
sql += ` AND model_type::text LIKE ? `
|
||||
args = append(args, prefix+"%")
|
||||
}
|
||||
|
||||
if !g.IsEmpty(req.IsPrivate) {
|
||||
sql += ` AND is_private = ? `
|
||||
args = append(args, req.IsPrivate)
|
||||
}
|
||||
|
||||
if req.IsOwner != nil && *req.IsOwner == 0 {
|
||||
if req.Enabled != nil && *req.Enabled == 1 {
|
||||
sql += ` AND creator = ? AND is_owner = ? AND enabled=1 `
|
||||
@@ -114,9 +121,7 @@ WHERE deleted_at IS NULL
|
||||
} else {
|
||||
sql += ` AND creator = ? AND is_owner = ? `
|
||||
}
|
||||
|
||||
args = append(args, req.Creator)
|
||||
args = append(args, req.IsOwner)
|
||||
args = append(args, req.Creator, req.IsOwner)
|
||||
} else if req.IsOwner != nil && *req.IsOwner == 1 {
|
||||
if req.Enabled != nil && *req.Enabled == 1 {
|
||||
sql += ` AND ((creator = ? AND is_owner = ? AND enabled=1) OR (is_owner = 0 AND enabled=1)) `
|
||||
@@ -125,11 +130,9 @@ WHERE deleted_at IS NULL
|
||||
} else {
|
||||
sql += ` AND ((creator = ? AND is_owner = ?) OR (is_owner = 0 AND enabled=1)) `
|
||||
}
|
||||
args = append(args, req.Creator)
|
||||
args = append(args, req.IsOwner)
|
||||
args = append(args, req.Creator, req.IsOwner)
|
||||
}
|
||||
|
||||
// 最后拼接排序
|
||||
sql += ` ORDER BY model_name, is_owner DESC, created_at DESC`
|
||||
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, sql, args...)
|
||||
|
||||
7
main.go
7
main.go
@@ -3,13 +3,14 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"model-gateway/model/dto"
|
||||
"model-gateway/service/job"
|
||||
"model-gateway/service/task"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"model-gateway/controller"
|
||||
"model-gateway/service"
|
||||
|
||||
"gitea.com/red-future/common/http"
|
||||
"gitea.com/red-future/common/jaeger"
|
||||
@@ -62,7 +63,7 @@ func startAutoRunner(ctx context.Context) {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if _, err := service.AsyncWorker.RunOnce(ctx, &dto.RunWorkReq{
|
||||
if _, err := task.AsyncWorker.RunOnce(ctx, &dto.RunWorkReq{
|
||||
BatchSize: batchSize,
|
||||
Goroutines: goroutines,
|
||||
}); err != nil {
|
||||
@@ -87,7 +88,7 @@ func startAutoRunner(ctx context.Context) {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
_, _ = service.Cleaner.RunOnce(ctx)
|
||||
_, _ = job.Cleaner.RunOnce(ctx)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -10,33 +10,33 @@ import (
|
||||
// CreateModelReq 添加模型配置
|
||||
type CreateModelReq struct {
|
||||
g.Meta `path:"/createModel" method:"post" tags:"模型管理" summary:"创建模型配置" dc:"添加新的模型配置"`
|
||||
ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"模型名称(唯一标识)"`
|
||||
ModelType int `p:"modelType" json:"modelType" v:"required#modelType不能为空" dc:"模型类型:1-文本生成 2-图像生成 3-语音 4-视频 5-多模态"`
|
||||
BaseURL string `p:"baseUrl" json:"baseUrl" v:"required#baseUrl不能为空" dc:"模型服务基础地址(如 gateway(s)://host:port)"`
|
||||
HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式:GET/POST(默认POST)"`
|
||||
HeadMsg string `p:"headMsg" json:"headMsg" dc:"请求头绑定(支持多个,逗号分隔),示例:Authorization:Bearer xxx,Content-Type:application/json"`
|
||||
IsPrivate *int `p:"isPrivate" json:"isPrivate" v:"in:0,1#私有化参数只能为0或1" dc:"是否私有化:0-私有(默认) 1-公共"`
|
||||
Enabled *int `p:"enabled" json:"enabled" v:"in:0,1#启用参数只能为0或1" dc:"是否启用:0-禁用,1-启用(默认1)"`
|
||||
IsChatModel *int `p:"isChatModel" json:"isChatModel" v:"in:0,1#对话模型参数只能为0或1" dc:"是否为对话模型:0-否,1-是(默认0)"`
|
||||
IsOwner *int `p:"isOwner" json:"isOwner" v:"in:0,1#是否为所有者参数只能为0或1" dc:"是否为所有者:0-否,1-是(默认0)"`
|
||||
OperatorName string `p:"operatorName" json:"operatorName" v:"required#operatorName不能为空" dc:"运营商名称"`
|
||||
TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"`
|
||||
ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"`
|
||||
QueryConfig map[string]any `p:"queryConfig" json:"queryConfig" dc:"结果配置"`
|
||||
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥,用于模型认证"`
|
||||
Form map[string]any `p:"form" json:"form" dc:"动态表单配置(JSON),用于前端渲染配置项"`
|
||||
RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求映射"`
|
||||
ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回映射"`
|
||||
ResponseBody map[string]any `p:"responseBody" json:"responseBody" dc:"返回主体"`
|
||||
ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"`
|
||||
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数(默认10)"`
|
||||
QueueLimit int `p:"queueLimit" json:"queueLimit" dc:"排队队列上限(默认1000)"`
|
||||
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒,默认600)"`
|
||||
ExpectedSeconds int `p:"expectedSeconds" json:"expectedSeconds" dc:"模型预计执行时间(秒,默认600)"`
|
||||
RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数(默认3)"`
|
||||
RetryQueueMaxSeconds int `p:"retryQueueMaxSeconds" json:"retryQueueMaxSeconds" dc:"失败重试最大排队时间(秒,默认600)"`
|
||||
AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"任务完成后自动清理时间(秒,默认86400)"`
|
||||
Remark string `p:"remark" json:"remark" dc:"备注说明"`
|
||||
ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"模型名称(唯一标识)"`
|
||||
ModelType int `p:"modelType" json:"modelType" v:"required#modelType不能为空" dc:"模型类型:1-文本生成 2-图像生成 3-语音 4-视频 5-多模态"`
|
||||
BaseURL string `p:"baseUrl" json:"baseUrl" v:"required#baseUrl不能为空" dc:"模型服务基础地址(如 gateway(s)://host:port)"`
|
||||
HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式:GET/POST(默认POST)"`
|
||||
HeadMsg string `p:"headMsg" json:"headMsg" dc:"请求头绑定(支持多个,逗号分隔),示例:Authorization:Bearer xxx,Content-Type:application/json"`
|
||||
IsPrivate *int `p:"isPrivate" json:"isPrivate" v:"in:0,1#私有化参数只能为0或1" dc:"是否私有化:0-私有(默认) 1-公共"`
|
||||
Enabled *int `p:"enabled" json:"enabled" v:"in:0,1#启用参数只能为0或1" dc:"是否启用:0-禁用,1-启用(默认1)"`
|
||||
IsChatModel *int `p:"isChatModel" json:"isChatModel" v:"in:0,1#对话模型参数只能为0或1" dc:"是否为对话模型:0-否,1-是(默认0)"`
|
||||
IsOwner *int `p:"isOwner" json:"isOwner" v:"in:0,1#是否为所有者参数只能为0或1" dc:"是否为所有者:0-否,1-是(默认0)"`
|
||||
OperatorName string `p:"operatorName" json:"operatorName" v:"required#operatorName不能为空" dc:"运营商名称"`
|
||||
TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"`
|
||||
ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"`
|
||||
QueryConfig map[string]any `p:"queryConfig" json:"queryConfig" dc:"结果配置"`
|
||||
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥,用于模型认证"`
|
||||
Form []map[string]any `p:"form" json:"form" dc:"动态表单配置(JSON),用于前端渲染配置项"`
|
||||
RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求映射"`
|
||||
ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回映射"`
|
||||
ResponseBody map[string]any `p:"responseBody" json:"responseBody" dc:"返回主体"`
|
||||
ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"`
|
||||
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数(默认10)"`
|
||||
QueueLimit int `p:"queueLimit" json:"queueLimit" dc:"排队队列上限(默认1000)"`
|
||||
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒,默认600)"`
|
||||
ExpectedSeconds int `p:"expectedSeconds" json:"expectedSeconds" dc:"模型预计执行时间(秒,默认600)"`
|
||||
RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数(默认3)"`
|
||||
RetryQueueMaxSeconds int `p:"retryQueueMaxSeconds" json:"retryQueueMaxSeconds" dc:"失败重试最大排队时间(秒,默认600)"`
|
||||
AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"任务完成后自动清理时间(秒,默认86400)"`
|
||||
Remark string `p:"remark" json:"remark" dc:"备注说明"`
|
||||
}
|
||||
|
||||
type CreateModelRes struct {
|
||||
@@ -45,34 +45,34 @@ type CreateModelRes struct {
|
||||
|
||||
type UpdateModelReq struct {
|
||||
g.Meta `path:"/updateModel" method:"put" tags:"模型管理" summary:"更新模型配置" dc:"更新指定ID的模型配置"`
|
||||
ID int64 `p:"id" json:"id" v:"required#id不能为空" dc:"配置ID"`
|
||||
ModelName string `p:"modelName" json:"modelName" dc:"模型名称(唯一标识)"`
|
||||
ModelType int `p:"modelType" json:"modelType" dc:"模型类型ID列表(逗号分隔)(可选更新)"`
|
||||
BaseURL string `p:"baseUrl" json:"baseUrl" dc:"模型服务基础地址"`
|
||||
HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式:GET/POST(可选更新)"`
|
||||
HeadMsg string `p:"headMsg" json:"headMsg" dc:"请求头绑定(可选更新)"`
|
||||
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥,用于模型认证(可选更新)"`
|
||||
Form map[string]any `p:"form" json:"form" dc:"动态表单配置(JSON)(可选更新)"`
|
||||
RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求参数映射(可选更新)"`
|
||||
ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回参数映射(可选更新)"`
|
||||
ResponseBody map[string]any `p:"responseBody" json:"responseBody" dc:"返回主体(可选更新)"`
|
||||
ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"`
|
||||
Enabled *int `p:"enabled" json:"enabled" dc:"是否启用:0-禁用,1-启用(可选更新)"`
|
||||
IsPrivate *int `p:"isPrivate" json:"isPrivate" v:"in:0,1#私有化参数只能为0或1" dc:"是否私有化:0-私有(默认) 1-公共"`
|
||||
IsChatModel *int `p:"isChatModel" json:"isChatModel" v:"in:0,1#对话模型参数只能为0或1" dc:"是否为对话模型:0-否,1-是(默认0)"`
|
||||
IsOwner *int `p:"isOwner" json:"isOwner" v:"in:0,1#是否为所有者参数只能为0或1" dc:"是否为所有者:0-否,1-是(默认0)"`
|
||||
OperatorName string `p:"operatorName" json:"operatorName" v:"required#operatorName不能为空" dc:"运营商名称"`
|
||||
TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"`
|
||||
ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"`
|
||||
QueryConfig map[string]any `p:"queryConfig" json:"queryConfig" dc:"结果配置"`
|
||||
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数(可选更新)"`
|
||||
QueueLimit int `p:"queueLimit" json:"queueLimit" dc:"排队队列上限(可选更新)"`
|
||||
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒)(可选更新)"`
|
||||
ExpectedSeconds int `p:"expectedSeconds" json:"expectedSeconds" dc:"模型预计执行时间(秒)(可选更新)"`
|
||||
RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数(可选更新)"`
|
||||
RetryQueueMaxSeconds int `p:"retryQueueMaxSeconds" json:"retryQueueMaxSeconds" dc:"失败重试最大排队时间(秒)(可选更新)"`
|
||||
AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"自动清理间隔(秒)(可选更新)"`
|
||||
Remark string `p:"remark" json:"remark" dc:"备注说明(可选更新)"`
|
||||
ID int64 `p:"id" json:"id" v:"required#id不能为空" dc:"配置ID"`
|
||||
ModelName string `p:"modelName" json:"modelName" dc:"模型名称(唯一标识)"`
|
||||
ModelType int `p:"modelType" json:"modelType" dc:"模型类型ID列表(逗号分隔)(可选更新)"`
|
||||
BaseURL string `p:"baseUrl" json:"baseUrl" dc:"模型服务基础地址"`
|
||||
HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式:GET/POST(可选更新)"`
|
||||
HeadMsg string `p:"headMsg" json:"headMsg" dc:"请求头绑定(可选更新)"`
|
||||
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥,用于模型认证(可选更新)"`
|
||||
Form []map[string]any `p:"form" json:"form" dc:"动态表单配置(JSON)(可选更新)"`
|
||||
RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求参数映射(可选更新)"`
|
||||
ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回参数映射(可选更新)"`
|
||||
ResponseBody map[string]any `p:"responseBody" json:"responseBody" dc:"返回主体(可选更新)"`
|
||||
ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"`
|
||||
Enabled *int `p:"enabled" json:"enabled" dc:"是否启用:0-禁用,1-启用(可选更新)"`
|
||||
IsPrivate *int `p:"isPrivate" json:"isPrivate" v:"in:0,1#私有化参数只能为0或1" dc:"是否私有化:0-私有(默认) 1-公共"`
|
||||
IsChatModel *int `p:"isChatModel" json:"isChatModel" v:"in:0,1#对话模型参数只能为0或1" dc:"是否为对话模型:0-否,1-是(默认0)"`
|
||||
IsOwner *int `p:"isOwner" json:"isOwner" v:"in:0,1#是否为所有者参数只能为0或1" dc:"是否为所有者:0-否,1-是(默认0)"`
|
||||
OperatorName string `p:"operatorName" json:"operatorName" v:"required#operatorName不能为空" dc:"运营商名称"`
|
||||
TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"`
|
||||
ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"`
|
||||
QueryConfig map[string]any `p:"queryConfig" json:"queryConfig" dc:"结果配置"`
|
||||
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数(可选更新)"`
|
||||
QueueLimit int `p:"queueLimit" json:"queueLimit" dc:"排队队列上限(可选更新)"`
|
||||
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒)(可选更新)"`
|
||||
ExpectedSeconds int `p:"expectedSeconds" json:"expectedSeconds" dc:"模型预计执行时间(秒)(可选更新)"`
|
||||
RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数(可选更新)"`
|
||||
RetryQueueMaxSeconds int `p:"retryQueueMaxSeconds" json:"retryQueueMaxSeconds" dc:"失败重试最大排队时间(秒)(可选更新)"`
|
||||
AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"自动清理间隔(秒)(可选更新)"`
|
||||
Remark string `p:"remark" json:"remark" dc:"备注说明(可选更新)"`
|
||||
}
|
||||
|
||||
type UpdateModelRes struct {
|
||||
@@ -166,3 +166,20 @@ type GetIsChatModelReq struct {
|
||||
type GetIsChatModelRes struct {
|
||||
Model any `json:"model" dc:"模型详情"`
|
||||
}
|
||||
|
||||
// NodeFormField 节点表单
|
||||
type NodeFormField struct {
|
||||
Value any `json:"value" dc:"字段值"`
|
||||
Field string `json:"field" dc:"字段标识"`
|
||||
Label string `json:"label" dc:"字段标签"`
|
||||
Type string `json:"type" dc:"字段类型"`
|
||||
Required bool `json:"required" dc:"是否必填"`
|
||||
Default any `json:"default,omitempty" dc:"默认值"`
|
||||
Options []SelectOption `json:"options" dc:"下拉选项列表"`
|
||||
FieldConstraint any `json:"fieldConstraint" dc:"字段约束"`
|
||||
}
|
||||
|
||||
type SelectOption struct {
|
||||
Label string `json:"label" dc:"选项标签"`
|
||||
Value string `json:"value" dc:"选项值"`
|
||||
}
|
||||
|
||||
@@ -69,32 +69,32 @@ var AsynchModelCol = asynchModelCol{
|
||||
// AsynchModel 异步模型配置
|
||||
type AsynchModel struct {
|
||||
beans.SQLBaseDO `orm:",inline"`
|
||||
ModelName string `orm:"model_name" json:"modelName"`
|
||||
ModelType int `orm:"model_type" json:"modelType"`
|
||||
BaseURL string `orm:"base_url" json:"baseUrl"`
|
||||
HttpMethod string `orm:"http_method" json:"httpMethod"`
|
||||
HeadMsg string `orm:"head_msg" json:"headMsg"`
|
||||
Form map[string]any `orm:"form_json" json:"form"`
|
||||
RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"`
|
||||
ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"`
|
||||
ResponseBody map[string]any `orm:"response_body" json:"responseBody"`
|
||||
ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"`
|
||||
Prompt string `orm:"prompt" json:"prompt"`
|
||||
IsPrivate *int `orm:"is_private" json:"isPrivate"`
|
||||
IsChatModel *int `orm:"is_chat_model" json:"isChatModel"`
|
||||
ApiKey string `orm:"api_key" json:"apiKey"`
|
||||
Enabled *int `orm:"enabled" json:"enabled"`
|
||||
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
|
||||
QueueLimit int `orm:"queue_limit" json:"queueLimit"`
|
||||
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
|
||||
ExpectedSeconds int `orm:"expected_seconds" json:"expectedSeconds"`
|
||||
RetryTimes int `orm:"retry_times" json:"retryTimes"`
|
||||
RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"retryQueueMaxSeconds"`
|
||||
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
|
||||
Remark string `orm:"remark" json:"remark"`
|
||||
IsOwner *int `json:"isOwner" orm:"is_owner"`
|
||||
OperatorName string `orm:"operator_name" json:"operatorName"`
|
||||
TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"`
|
||||
ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"`
|
||||
QueryConfig map[string]any `orm:"query_config" json:"queryConfig"`
|
||||
ModelName string `orm:"model_name" json:"modelName"`
|
||||
ModelType int `orm:"model_type" json:"modelType"`
|
||||
BaseURL string `orm:"base_url" json:"baseUrl"`
|
||||
HttpMethod string `orm:"http_method" json:"httpMethod"`
|
||||
HeadMsg string `orm:"head_msg" json:"headMsg"`
|
||||
Form []map[string]any `orm:"form_json" json:"form"`
|
||||
RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"`
|
||||
ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"`
|
||||
ResponseBody map[string]any `orm:"response_body" json:"responseBody"`
|
||||
ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"`
|
||||
Prompt string `orm:"prompt" json:"prompt"`
|
||||
IsPrivate *int `orm:"is_private" json:"isPrivate"`
|
||||
IsChatModel *int `orm:"is_chat_model" json:"isChatModel"`
|
||||
ApiKey string `orm:"api_key" json:"apiKey"`
|
||||
Enabled *int `orm:"enabled" json:"enabled"`
|
||||
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
|
||||
QueueLimit int `orm:"queue_limit" json:"queueLimit"`
|
||||
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
|
||||
ExpectedSeconds int `orm:"expected_seconds" json:"expectedSeconds"`
|
||||
RetryTimes int `orm:"retry_times" json:"retryTimes"`
|
||||
RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"retryQueueMaxSeconds"`
|
||||
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
|
||||
Remark string `orm:"remark" json:"remark"`
|
||||
IsOwner *int `json:"isOwner" orm:"is_owner"`
|
||||
OperatorName string `orm:"operator_name" json:"operatorName"`
|
||||
TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"`
|
||||
ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"`
|
||||
QueryConfig map[string]any `orm:"query_config" json:"queryConfig"`
|
||||
}
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
package service
|
||||
@@ -1,8 +1,9 @@
|
||||
package service
|
||||
package job
|
||||
|
||||
import (
|
||||
"context"
|
||||
"model-gateway/model/dto"
|
||||
"model-gateway/service/queue"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
@@ -20,32 +21,32 @@ func (c *cleaner) RunOnce(ctx context.Context) (res *dto.CleanWorkRes, err error
|
||||
// 1) 清理已下载(state=4)且过期的任务(硬删除 + OSS)
|
||||
expired, err := dao.Task.ListExpiredDownloadedGlobal(ctx, 200)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[cleaner] list expired(downloaded) error: %v", err)
|
||||
g.Log().Errorf(ctx, "[清理] 查询已下载过期任务失败: %v", err)
|
||||
} else {
|
||||
for _, t := range expired {
|
||||
_ = os.Remove(t.TmpFile)
|
||||
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
|
||||
}
|
||||
g.Log().Infof(ctx, "[cleaner] expired(downloaded) cleaned, count=%d", len(expired))
|
||||
g.Log().Infof(ctx, "[清理] 已下载过期任务清理完成, count=%d", len(expired))
|
||||
}
|
||||
|
||||
// 2) 超时任务标失败
|
||||
list, err := dao.Task.ListTimeoutTasksGlobal(ctx, 200)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[cleaner] list timeout error: %v", err)
|
||||
g.Log().Errorf(ctx, "[清理] 查询超时任务失败: %v", err)
|
||||
} else {
|
||||
for _, t := range list {
|
||||
t.ErrorMsg = "任务超时自动失败"
|
||||
_ = dao.Task.UpdateFailedGlobal(ctx, t)
|
||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
}
|
||||
g.Log().Infof(ctx, "[cleaner] timeout cleaned, count=%d", len(list))
|
||||
g.Log().Infof(ctx, "[清理] 超时任务处理完成, count=%d", len(list))
|
||||
}
|
||||
|
||||
// 3) 失败(state=3)的任务按模型配置 retry_times 重新入队(放到队尾)
|
||||
retryable, err := dao.Task.ListFailedRetryableGlobal(ctx, 200)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[cleaner] list failed retryable error: %v", err)
|
||||
g.Log().Errorf(ctx, "[清理] 查询可重试任务失败: %v", err)
|
||||
} else {
|
||||
for _, t := range retryable {
|
||||
// 失败任务重新入队(state=3 -> 0)前,先严格占用 queue_limit slot;占用失败则留在失败态,下一轮再尝试
|
||||
@@ -54,9 +55,9 @@ func (c *cleaner) RunOnce(ctx context.Context) (res *dto.CleanWorkRes, err error
|
||||
if err != nil || m == nil {
|
||||
continue
|
||||
}
|
||||
limit := GetRuntimeQueueLimit(ctx, t.ModelName, m.QueueLimit)
|
||||
limit := queue.GetRuntimeQueueLimit(ctx, t.ModelName, m.QueueLimit)
|
||||
if limit > 0 {
|
||||
ok, _ := AcquireQueueSlot(ctx, t.ModelName, t.TaskID, limit, m.ExpectedSeconds)
|
||||
ok, _ := queue.AcquireQueueSlot(ctx, t.ModelName, t.TaskID, limit, m.ExpectedSeconds)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
@@ -76,21 +77,21 @@ func (c *cleaner) RunOnce(ctx context.Context) (res *dto.CleanWorkRes, err error
|
||||
}
|
||||
_ = dao.Task.RequeueForRetryGlobal(ctx, t.Id, enqueueAt)
|
||||
}
|
||||
g.Log().Infof(ctx, "[cleaner] failed retryable cleaned, count=%d", len(retryable))
|
||||
g.Log().Infof(ctx, "[清理] 可重试任务重新入队完成, count=%d", len(retryable))
|
||||
}
|
||||
|
||||
// 4) 超过重试次数仍失败(state=3)的任务:硬删除
|
||||
exhausted, err := dao.Task.ListFailedExhaustedGlobal(ctx, 200)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[cleaner] list failed exhausted error: %v", err)
|
||||
g.Log().Errorf(ctx, "[清理] 查询重试耗尽任务失败: %v", err)
|
||||
} else {
|
||||
for _, t := range exhausted {
|
||||
_ = os.Remove(t.TmpFile)
|
||||
// 重试耗尽硬删除:释放闸门占位(兜底,若此前已释放则幂等)
|
||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
|
||||
}
|
||||
g.Log().Infof(ctx, "[cleaner] failed exhausted cleaned, count=%d", len(exhausted))
|
||||
g.Log().Infof(ctx, "[清理] 重试耗尽任务清理完成, count=%d", len(exhausted))
|
||||
}
|
||||
return &dto.CleanWorkRes{
|
||||
Ok: true,
|
||||
254
service/model/model_service.go
Normal file
254
service/model/model_service.go
Normal file
@@ -0,0 +1,254 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"model-gateway/common/util"
|
||||
"model-gateway/consts/public"
|
||||
"model-gateway/dao"
|
||||
"model-gateway/model/dto"
|
||||
"model-gateway/model/entity"
|
||||
"model-gateway/service/gateway"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
var Model = &modelService{}
|
||||
|
||||
type modelService struct{}
|
||||
|
||||
// Create 创建模型
|
||||
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (*dto.CreateModelRes, error) {
|
||||
// 1)如果设为会话模型,先把该用户旧会话模型取消
|
||||
if !g.IsEmpty(req.IsChatModel) && *req.IsChatModel == 1 {
|
||||
if err := s.clearUserChatModel(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
// 2)判断是否超管,决定 isOwner
|
||||
req.IsOwner = gconv.PtrInt(1)
|
||||
if isAdmin, _ := gateway.IsSuperAdmin(ctx); isAdmin {
|
||||
req.IsOwner = gconv.PtrInt(0)
|
||||
}
|
||||
|
||||
// 3)入库
|
||||
id, err := dao.Model.Insert(ctx, util.ConvertTo[entity.AsynchModel](req))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.CreateModelRes{ID: id}, nil
|
||||
}
|
||||
|
||||
// Update 更新模型配置
|
||||
func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) error {
|
||||
// 1)会话模型唯一性校验
|
||||
if req.IsChatModel != nil && *req.IsChatModel == 1 {
|
||||
if err := s.checkChatModelUnique(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// 2)超管创建/普通用户更新
|
||||
req.IsOwner = gconv.PtrInt(1)
|
||||
if isAdmin, _ := gateway.IsSuperAdmin(ctx); isAdmin {
|
||||
req.IsOwner = gconv.PtrInt(0)
|
||||
_, err := dao.Model.Update(ctx, util.ConvertTo[entity.AsynchModel](req))
|
||||
return err
|
||||
}
|
||||
// 3)跨租户判断:超管的模型不允许直接修改,走插入新记录
|
||||
model, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if model.TenantId == 1 {
|
||||
_, err = dao.Model.Insert(ctx, util.ConvertTo[entity.AsynchModel](req))
|
||||
return err
|
||||
}
|
||||
_, err = dao.Model.Update(ctx, util.ConvertTo[entity.AsynchModel](req))
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete 删除模型
|
||||
func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) error {
|
||||
_, err := dao.Model.Delete(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// Get 获取模型详情
|
||||
func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetModelRes, error) {
|
||||
user, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if g.IsEmpty(req.ID) {
|
||||
req.Creator = user.UserName
|
||||
}
|
||||
modelReq := new(entity.AsynchModel)
|
||||
err = gconv.Struct(req, modelReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
model, err := dao.Model.Get(ctx, modelReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.GetModelRes{
|
||||
Model: model,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// List 获取模型列表
|
||||
func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (*dto.ListModelRes, error) {
|
||||
// 1)判断超管
|
||||
req.IsOwner = gconv.PtrInt(1)
|
||||
if isAdmin, _ := gateway.IsSuperAdmin(ctx); isAdmin {
|
||||
req.IsOwner = gconv.PtrInt(0)
|
||||
}
|
||||
|
||||
// 2)获取当前用户
|
||||
user, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Creator = user.UserName
|
||||
|
||||
// 3)查询
|
||||
models, total, err := dao.Model.GetByCreatorAndPlatform(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &dto.ListModelRes{List: models, Total: total}, nil
|
||||
}
|
||||
|
||||
// UpdateChatModel 设置会话模型
|
||||
func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error {
|
||||
// 1)校验新模型存在
|
||||
newModel, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
|
||||
})
|
||||
if err != nil || newModel == nil {
|
||||
return errors.New("新会话模型不存在")
|
||||
}
|
||||
|
||||
// 2)获取当前用户的会话模型
|
||||
user, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
currentModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
|
||||
IsChatModel: new(1),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 3)事务:取消旧的 + 设置新的
|
||||
return gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
||||
if !g.IsEmpty(currentModel) {
|
||||
if currentModel.ModelType != public.ModelTypeInference {
|
||||
return errors.New("当前模型为非推理模型,不能设置为会话模型")
|
||||
}
|
||||
if currentModel.Id != req.Id {
|
||||
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id},
|
||||
IsChatModel: gconv.PtrInt(0),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
|
||||
IsChatModel: gconv.PtrInt(1),
|
||||
})
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
// GetIsChatModel 获取当前用户会话模型
|
||||
func (s *modelService) GetIsChatModel(ctx context.Context) (*dto.GetIsChatModelRes, error) {
|
||||
user, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
|
||||
IsChatModel: new(1),
|
||||
})
|
||||
if err != nil || model == nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.GetIsChatModelRes{Model: model}, nil
|
||||
}
|
||||
|
||||
// ==================== 辅助方法 ====================
|
||||
|
||||
// clearUserChatModel 清除当前用户旧会话模型
|
||||
func (s *modelService) clearUserChatModel(ctx context.Context) error {
|
||||
user, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
|
||||
IsChatModel: new(1),
|
||||
})
|
||||
if err != nil || model == nil {
|
||||
return nil
|
||||
}
|
||||
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: model.Id},
|
||||
IsChatModel: gconv.PtrInt(0),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// checkChatModelUnique 校验用户是否已有会话模型
|
||||
func (s *modelService) checkChatModelUnique(ctx context.Context) error {
|
||||
user, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
|
||||
IsChatModel: new(1),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if model != nil {
|
||||
return errors.New("用户已存在会话模型")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetModelTypesFromConfig 从配置文件读取模型类型
|
||||
func GetModelTypesFromConfig() (res *dto.TypeItem, err error) {
|
||||
// 返回副本,避免外部修改
|
||||
types := make(map[int]string, len(public.ModelTypeName))
|
||||
for k, v := range public.ModelTypeName {
|
||||
types[k] = v
|
||||
}
|
||||
return &dto.TypeItem{
|
||||
Type: types,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetOperatorList 获取运营商列表
|
||||
func GetOperatorList() (res *dto.ListOperatorRes, err error) {
|
||||
return &dto.ListOperatorRes{
|
||||
List: public.OperatorList,
|
||||
}, nil
|
||||
}
|
||||
@@ -1,469 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"model-gateway/model/entity"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gogf/gf/v2/container/gvar"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// parseHeadMsgHeaders 支持多个 header 绑定,逗号分隔:
|
||||
// 示例:
|
||||
// - X-API-Key:qwen3-tts-key,operation:true,count:123
|
||||
// - X-API-Key:"qwen3-tts-key",operation:"true"
|
||||
//
|
||||
// 说明:
|
||||
// - HTTP Header 最终都是字符串,这里做的是“值的字符串化表达”。
|
||||
// - 若 value 用双引号包裹,会去掉外层引号再注入,便于在配置中区分字符串/布尔/数字等表达(以及避免值中包含特殊字符时歧义)。
|
||||
func parseHeadMsgHeaders(headMsg string) map[string]string {
|
||||
headMsg = strings.TrimSpace(headMsg)
|
||||
if headMsg == "" {
|
||||
return nil
|
||||
}
|
||||
out := map[string]string{}
|
||||
parts := strings.Split(headMsg, ",")
|
||||
for _, p := range parts {
|
||||
p = strings.TrimSpace(p)
|
||||
if p == "" {
|
||||
continue
|
||||
}
|
||||
// HeaderName:HeaderValue(推荐) / HeaderName=HeaderValue(兼容)
|
||||
if strings.Contains(p, ":") {
|
||||
kv := strings.SplitN(p, ":", 2)
|
||||
k := strings.TrimSpace(kv[0])
|
||||
v := strings.TrimSpace(kv[1])
|
||||
v = strings.Trim(v, "\"")
|
||||
if k != "" && v != "" {
|
||||
out[k] = v
|
||||
}
|
||||
continue
|
||||
}
|
||||
if strings.Contains(p, "=") {
|
||||
kv := strings.SplitN(p, "=", 2)
|
||||
k := strings.TrimSpace(kv[0])
|
||||
v := strings.TrimSpace(kv[1])
|
||||
v = strings.Trim(v, "\"")
|
||||
if k != "" && v != "" {
|
||||
out[k] = v
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func payloadToQuery(payload any) (url.Values, error) {
|
||||
if payload == nil {
|
||||
return url.Values{}, nil
|
||||
}
|
||||
// 统一转成 map[string]any
|
||||
b, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m := map[string]any{}
|
||||
if err := json.Unmarshal(b, &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
q := url.Values{}
|
||||
for k, v := range m {
|
||||
if v == nil {
|
||||
continue
|
||||
}
|
||||
// 复杂类型直接 json 字符串化
|
||||
switch vv := v.(type) {
|
||||
case string:
|
||||
q.Set(k, vv)
|
||||
case float64, bool, int, int64, uint64:
|
||||
q.Set(k, fmt.Sprintf("%v", vv))
|
||||
default:
|
||||
bs, _ := json.Marshal(v)
|
||||
q.Set(k, string(bs))
|
||||
}
|
||||
}
|
||||
return q, nil
|
||||
}
|
||||
|
||||
// InvokeModel 调用模型服务,返回二进制结果
|
||||
// modelKey 用于覆盖/补充模型配置 head_msg(例如每次请求携带不同的 X-API-Key)。
|
||||
func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelKey string) ([]byte, error) {
|
||||
if m == nil || m.BaseURL == "" {
|
||||
return nil, fmt.Errorf("模型配置不完整")
|
||||
}
|
||||
|
||||
// ============ 新增:请求参数映射 ============
|
||||
mappedPayload, err := mapRequestPayload(m.RequestMapping, payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求参数映射失败: %w", err)
|
||||
}
|
||||
|
||||
url := strings.TrimRight(m.BaseURL, "/")
|
||||
timeout := time.Duration(m.TimeoutSeconds) * time.Second
|
||||
if timeout <= 0 {
|
||||
timeout = 60 * time.Second
|
||||
}
|
||||
client := &http.Client{Timeout: timeout}
|
||||
|
||||
method := strings.ToUpper(strings.TrimSpace(m.HttpMethod))
|
||||
if method == "" {
|
||||
method = http.MethodPost
|
||||
}
|
||||
|
||||
var (
|
||||
req *http.Request
|
||||
)
|
||||
switch method {
|
||||
case http.MethodGet:
|
||||
q, err := payloadToQuery(mappedPayload) // 使用映射后的payload
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(q) > 0 {
|
||||
if strings.Contains(url, "?") {
|
||||
url = url + "&" + q.Encode()
|
||||
} else {
|
||||
url = url + "?" + q.Encode()
|
||||
}
|
||||
}
|
||||
req, err = http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
default:
|
||||
bodyBytes, err := json.Marshal(mappedPayload) // 使用映射后的payload
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err = http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyBytes))
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 先注入模型配置 head_msg(静态头部,适合公共模型固定 API Key)
|
||||
for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) {
|
||||
req.Header.Set(hk, hv)
|
||||
}
|
||||
|
||||
// 最后注入动态 modelKey(允许覆盖/补充静态 head_msg),适合按请求动态传密钥。
|
||||
for hk, hv := range parseHeadMsgHeaders(modelKey) {
|
||||
req.Header.Set(hk, hv)
|
||||
}
|
||||
|
||||
if method != http.MethodGet {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
msg := string(b)
|
||||
if len(msg) > 2000 {
|
||||
msg = msg[:2000]
|
||||
}
|
||||
return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
|
||||
}
|
||||
|
||||
// ============ 新增:响应参数映射 ============
|
||||
mappedResponse, err := mapResponsePayload(m.ResponseMapping, b)
|
||||
if err != nil {
|
||||
// 响应映射失败不阻塞,返回原始数据
|
||||
g.Log().Warningf(ctx, "响应参数映射失败: %v,返回原始数据", err)
|
||||
return b, nil
|
||||
}
|
||||
// =========================================
|
||||
|
||||
return mappedResponse, nil
|
||||
}
|
||||
|
||||
//// InvokeModel 调用模型服务,返回二进制结果
|
||||
//func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelKey string) ([]byte, error) {
|
||||
// if m == nil || m.BaseURL == "" {
|
||||
// return nil, fmt.Errorf("模型配置不完整")
|
||||
// }
|
||||
// // 请求参数映射
|
||||
// mappedPayload, err := mapRequestPayload(m.RequestMapping, payload)
|
||||
// if err != nil {
|
||||
// return nil, fmt.Errorf("请求参数映射失败: %w", err)
|
||||
// }
|
||||
// // 合并请求头
|
||||
// headers := util.ForwardHeaders(ctx)
|
||||
// for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) {
|
||||
// headers[hk] = hv
|
||||
// }
|
||||
// for hk, hv := range parseHeadMsgHeaders(modelKey) {
|
||||
// headers[hk] = hv
|
||||
// }
|
||||
//
|
||||
// // 设置超时
|
||||
// timeout := time.Duration(m.TimeoutSeconds) * time.Second
|
||||
// if timeout <= 0 {
|
||||
// timeout = 600 * time.Second
|
||||
// }
|
||||
// ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
// defer cancel()
|
||||
//
|
||||
// invokeUrl := strings.TrimRight(m.BaseURL, "/")
|
||||
// method := strings.ToUpper(strings.TrimSpace(m.HttpMethod))
|
||||
// if method == "" {
|
||||
// method = http.MethodPost
|
||||
// }
|
||||
//
|
||||
// var respBytes []byte
|
||||
//
|
||||
// switch method {
|
||||
// case http.MethodGet:
|
||||
// err = commonHttp.Get(ctx, invokeUrl, headers, &respBytes, mappedPayload)
|
||||
// default:
|
||||
// err = commonHttp.Post(ctx, invokeUrl, headers, &respBytes, mappedPayload)
|
||||
// }
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// // 响应参数映射
|
||||
// mappedResponse, err := mapResponsePayload(m.ResponseMapping, respBytes)
|
||||
// if err != nil {
|
||||
// g.Log().Warningf(ctx, "响应参数映射失败: %v,返回原始数据", err)
|
||||
// return respBytes, nil
|
||||
// }
|
||||
// return mappedResponse, nil
|
||||
//}
|
||||
|
||||
// ============================================
|
||||
// 映射相关函数
|
||||
// ============================================
|
||||
|
||||
// mapRequestPayload 将标准请求映射为模型特定格式
|
||||
func mapRequestPayload(mappingAny any, payload any) (any, error) {
|
||||
// 1. 解析请求映射配置(值是any类型,支持bool、number等)
|
||||
mapping, err := parseRequestMapping(mappingAny)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 如果没有映射配置,直接返回原始payload
|
||||
if len(mapping) == 0 {
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
// 2. 将payload转为map
|
||||
var payloadMap map[string]any
|
||||
switch v := payload.(type) {
|
||||
case map[string]any:
|
||||
payloadMap = v
|
||||
case []map[string]any:
|
||||
// 如果传进来的是纯messages数组,包装成标准格式
|
||||
payloadMap = map[string]any{
|
||||
"messages": v,
|
||||
}
|
||||
default:
|
||||
// 通过JSON转换
|
||||
jsonBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化payload失败: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(jsonBytes, &payloadMap); err != nil {
|
||||
return nil, fmt.Errorf("反序列化payload失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 用数据库固定参数覆盖/补充
|
||||
for key, value := range mapping {
|
||||
if existingValue, exists := payloadMap[key]; !exists || isEmptyValue(existingValue) {
|
||||
payloadMap[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
return payloadMap, nil
|
||||
}
|
||||
|
||||
// mapResponsePayload 将模型响应映射为标准格式
|
||||
func mapResponsePayload(mappingAny any, responseBytes []byte) ([]byte, error) {
|
||||
mapping, err := parseResponseMapping(mappingAny)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(mapping) == 0 {
|
||||
return responseBytes, nil
|
||||
}
|
||||
|
||||
responseStr := string(responseBytes)
|
||||
resultStr := `{}`
|
||||
|
||||
for standardField, modelPath := range mapping {
|
||||
value := gjson.Get(responseStr, modelPath)
|
||||
if !value.Exists() {
|
||||
continue
|
||||
}
|
||||
|
||||
resultStr, err = sjson.SetRaw(resultStr, standardField, value.Raw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("提取字段 %s <- %s 失败: %w", standardField, modelPath, err)
|
||||
}
|
||||
}
|
||||
|
||||
return []byte(resultStr), nil
|
||||
}
|
||||
|
||||
func parseRequestMapping(mappingAny any) (map[string]any, error) {
|
||||
if mappingAny == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
result := make(map[string]any)
|
||||
|
||||
switch v := mappingAny.(type) {
|
||||
case *gvar.Var:
|
||||
if v == nil || v.IsNil() || v.IsEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
// 尝试转成 map
|
||||
if m := v.Map(); m != nil {
|
||||
for k, val := range m {
|
||||
result[k] = val
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
// 尝试转成 string
|
||||
if s := v.String(); s != "" && s != "{}" && s != "null" {
|
||||
if err := json.Unmarshal([]byte(s), &result); err != nil {
|
||||
return nil, fmt.Errorf("解析请求映射字符串失败: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
return nil, nil
|
||||
// =======================================================
|
||||
|
||||
case map[string]interface{}:
|
||||
result = v
|
||||
|
||||
case string:
|
||||
if v == "" || v == "{}" || v == "null" {
|
||||
return nil, nil
|
||||
}
|
||||
if err := json.Unmarshal([]byte(v), &result); err != nil {
|
||||
return nil, fmt.Errorf("解析请求映射字符串失败: %w", err)
|
||||
}
|
||||
|
||||
case []byte:
|
||||
if len(v) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if err := json.Unmarshal(v, &result); err != nil {
|
||||
return nil, fmt.Errorf("解析请求映射字节失败: %w", err)
|
||||
}
|
||||
|
||||
default:
|
||||
jsonBytes, err := json.Marshal(mappingAny)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化映射配置失败: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(jsonBytes, &result); err != nil {
|
||||
return nil, fmt.Errorf("解析映射配置失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// parseResponseMapping 解析响应映射配置
|
||||
// 返回值类型为 map[string]string,值都是JSON路径字符串
|
||||
func parseResponseMapping(mappingAny any) (map[string]string, error) {
|
||||
if mappingAny == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
mapping := make(map[string]string)
|
||||
|
||||
switch v := mappingAny.(type) {
|
||||
case *gvar.Var:
|
||||
if v == nil || v.IsNil() || v.IsEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
if m := v.Map(); m != nil {
|
||||
for k, val := range m {
|
||||
if strVal, ok := val.(string); ok {
|
||||
mapping[k] = strVal
|
||||
}
|
||||
}
|
||||
return mapping, nil
|
||||
}
|
||||
if s := v.String(); s != "" && s != "{}" && s != "null" {
|
||||
if err := json.Unmarshal([]byte(s), &mapping); err != nil {
|
||||
return nil, fmt.Errorf("解析响应映射字符串失败: %w", err)
|
||||
}
|
||||
return mapping, nil
|
||||
}
|
||||
return nil, nil
|
||||
case string:
|
||||
if v == "" || v == "{}" || v == "null" {
|
||||
return nil, nil
|
||||
}
|
||||
if err := json.Unmarshal([]byte(v), &mapping); err != nil {
|
||||
return nil, fmt.Errorf("解析响应映射字符串失败: %w", err)
|
||||
}
|
||||
|
||||
case map[string]interface{}:
|
||||
// 数据库JSONB直接返回的map
|
||||
for k, val := range v {
|
||||
if strVal, ok := val.(string); ok {
|
||||
mapping[k] = strVal
|
||||
}
|
||||
}
|
||||
|
||||
case []byte:
|
||||
if len(v) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if err := json.Unmarshal(v, &mapping); err != nil {
|
||||
return nil, fmt.Errorf("解析响应映射字节失败: %w", err)
|
||||
}
|
||||
|
||||
default:
|
||||
jsonBytes, err := json.Marshal(mappingAny)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化响应映射配置失败: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(jsonBytes, &mapping); err != nil {
|
||||
return nil, fmt.Errorf("解析响应映射配置失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return mapping, nil
|
||||
}
|
||||
|
||||
// isEmptyValue 判断值是否为空
|
||||
func isEmptyValue(v any) bool {
|
||||
if v == nil {
|
||||
return true
|
||||
}
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
return val == ""
|
||||
case []any:
|
||||
return len(val) == 0
|
||||
case map[string]any:
|
||||
return len(val) == 0
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -1,389 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"model-gateway/consts/public"
|
||||
"model-gateway/dao"
|
||||
"model-gateway/model/dto"
|
||||
"model-gateway/model/entity"
|
||||
"model-gateway/service/gateway"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
var Model = &modelService{}
|
||||
|
||||
type modelService struct{}
|
||||
|
||||
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
|
||||
// 获取当前会话模型
|
||||
if !g.IsEmpty(req.IsChatModel) && *req.IsChatModel == 1 {
|
||||
var user *beans.User
|
||||
user, err = utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 获取当前用户会话模型
|
||||
var model *entity.AsynchModel
|
||||
model, err = dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{
|
||||
Creator: user.UserName,
|
||||
},
|
||||
IsChatModel: new(1),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 如果有会话模型,那就改变为 0
|
||||
if model != nil {
|
||||
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: model.Id},
|
||||
IsChatModel: gconv.PtrInt(0),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
req.IsOwner = gconv.PtrInt(1)
|
||||
admin, err := gateway.IsSuperAdmin(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if admin {
|
||||
req.IsOwner = gconv.PtrInt(0)
|
||||
}
|
||||
id, err := dao.Model.Insert(ctx, &entity.AsynchModel{
|
||||
ModelName: req.ModelName,
|
||||
ModelType: req.ModelType,
|
||||
BaseURL: req.BaseURL,
|
||||
HttpMethod: req.HttpMethod,
|
||||
HeadMsg: req.HeadMsg,
|
||||
Form: req.Form,
|
||||
RequestMapping: req.RequestMapping,
|
||||
ResponseMapping: req.ResponseMapping,
|
||||
ResponseBody: req.ResponseBody,
|
||||
ResponseTokenField: req.ResponseTokenField,
|
||||
IsPrivate: req.IsPrivate,
|
||||
IsChatModel: req.IsChatModel,
|
||||
ApiKey: req.ApiKey,
|
||||
Enabled: req.Enabled,
|
||||
MaxConcurrency: req.MaxConcurrency,
|
||||
QueueLimit: req.QueueLimit,
|
||||
TimeoutSeconds: req.TimeoutSeconds,
|
||||
ExpectedSeconds: req.ExpectedSeconds,
|
||||
RetryTimes: req.RetryTimes,
|
||||
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
|
||||
AutoCleanSeconds: req.AutoCleanSeconds,
|
||||
Remark: req.Remark,
|
||||
IsOwner: req.IsOwner,
|
||||
OperatorName: req.OperatorName,
|
||||
TokenConfig: req.TokenConfig,
|
||||
ExtendMapping: req.ExtendMapping,
|
||||
QueryConfig: req.QueryConfig,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.CreateModelRes{ID: id}, nil
|
||||
}
|
||||
|
||||
func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) error {
|
||||
//根据当前 isChatModel 来判断是否更新模型
|
||||
if req.IsChatModel == gconv.PtrInt(1) {
|
||||
user, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 获取当前用户会话模型
|
||||
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{
|
||||
Creator: user.UserName,
|
||||
},
|
||||
IsChatModel: new(1),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if model != nil {
|
||||
return errors.New("用户已存在会话模型,不能创建")
|
||||
}
|
||||
}
|
||||
|
||||
req.IsOwner = gconv.PtrInt(1)
|
||||
admin, err := gateway.IsSuperAdmin(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if admin {
|
||||
req.IsOwner = gconv.PtrInt(0)
|
||||
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
|
||||
ModelName: req.ModelName,
|
||||
ModelType: req.ModelType,
|
||||
BaseURL: req.BaseURL,
|
||||
HttpMethod: req.HttpMethod,
|
||||
HeadMsg: req.HeadMsg,
|
||||
Form: req.Form,
|
||||
RequestMapping: req.RequestMapping,
|
||||
ResponseMapping: req.ResponseMapping,
|
||||
ResponseBody: req.ResponseBody,
|
||||
ResponseTokenField: req.ResponseTokenField,
|
||||
IsPrivate: req.IsPrivate,
|
||||
IsChatModel: req.IsChatModel,
|
||||
ApiKey: req.ApiKey,
|
||||
Enabled: req.Enabled,
|
||||
MaxConcurrency: req.MaxConcurrency,
|
||||
QueueLimit: req.QueueLimit,
|
||||
TimeoutSeconds: req.TimeoutSeconds,
|
||||
ExpectedSeconds: req.ExpectedSeconds,
|
||||
RetryTimes: req.RetryTimes,
|
||||
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
|
||||
AutoCleanSeconds: req.AutoCleanSeconds,
|
||||
Remark: req.Remark,
|
||||
IsOwner: req.IsOwner,
|
||||
OperatorName: req.OperatorName,
|
||||
TokenConfig: req.TokenConfig,
|
||||
ExtendMapping: req.ExtendMapping,
|
||||
QueryConfig: req.QueryConfig,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// 判断当前传过来的模型id的模型是否是超级管理员的。如果是超管的进行创建,否则更新
|
||||
model, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if model.TenantId == 1 {
|
||||
insertDto := new(dto.CreateModelReq)
|
||||
err = gconv.Struct(req, insertDto)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = dao.Model.Insert(ctx, &entity.AsynchModel{
|
||||
ModelName: req.ModelName,
|
||||
ModelType: req.ModelType,
|
||||
BaseURL: req.BaseURL,
|
||||
HttpMethod: req.HttpMethod,
|
||||
HeadMsg: req.HeadMsg,
|
||||
Form: req.Form,
|
||||
RequestMapping: req.RequestMapping,
|
||||
ResponseMapping: req.ResponseMapping,
|
||||
ResponseBody: req.ResponseBody,
|
||||
ResponseTokenField: req.ResponseTokenField,
|
||||
IsPrivate: req.IsPrivate,
|
||||
IsChatModel: req.IsChatModel,
|
||||
ApiKey: req.ApiKey,
|
||||
Enabled: req.Enabled,
|
||||
MaxConcurrency: req.MaxConcurrency,
|
||||
QueueLimit: req.QueueLimit,
|
||||
TimeoutSeconds: req.TimeoutSeconds,
|
||||
ExpectedSeconds: req.ExpectedSeconds,
|
||||
RetryTimes: req.RetryTimes,
|
||||
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
|
||||
AutoCleanSeconds: req.AutoCleanSeconds,
|
||||
Remark: req.Remark,
|
||||
IsOwner: req.IsOwner,
|
||||
OperatorName: req.OperatorName,
|
||||
TokenConfig: req.TokenConfig,
|
||||
ExtendMapping: req.ExtendMapping,
|
||||
QueryConfig: req.QueryConfig,
|
||||
})
|
||||
return err
|
||||
}
|
||||
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
|
||||
ModelName: req.ModelName,
|
||||
ModelType: req.ModelType,
|
||||
BaseURL: req.BaseURL,
|
||||
HttpMethod: req.HttpMethod,
|
||||
HeadMsg: req.HeadMsg,
|
||||
Form: req.Form,
|
||||
RequestMapping: req.RequestMapping,
|
||||
ResponseMapping: req.ResponseMapping,
|
||||
ResponseBody: req.ResponseBody,
|
||||
ResponseTokenField: req.ResponseTokenField,
|
||||
IsPrivate: req.IsPrivate,
|
||||
IsChatModel: req.IsChatModel,
|
||||
ApiKey: req.ApiKey,
|
||||
Enabled: req.Enabled,
|
||||
MaxConcurrency: req.MaxConcurrency,
|
||||
QueueLimit: req.QueueLimit,
|
||||
TimeoutSeconds: req.TimeoutSeconds,
|
||||
ExpectedSeconds: req.ExpectedSeconds,
|
||||
RetryTimes: req.RetryTimes,
|
||||
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
|
||||
AutoCleanSeconds: req.AutoCleanSeconds,
|
||||
Remark: req.Remark,
|
||||
IsOwner: req.IsOwner,
|
||||
OperatorName: req.OperatorName,
|
||||
TokenConfig: req.TokenConfig,
|
||||
ExtendMapping: req.ExtendMapping,
|
||||
QueryConfig: req.QueryConfig,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) error {
|
||||
_, err := dao.Model.Delete(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetModelRes, error) {
|
||||
user, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{
|
||||
Id: req.ID,
|
||||
Creator: user.UserName,
|
||||
},
|
||||
ModelName: req.ModelName,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.GetModelRes{
|
||||
Model: model,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) {
|
||||
var models []*entity.AsynchModel
|
||||
|
||||
req.IsOwner = gconv.PtrInt(1)
|
||||
admin, err := gateway.IsSuperAdmin(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if admin {
|
||||
req.IsOwner = gconv.PtrInt(0)
|
||||
}
|
||||
|
||||
var user *beans.User
|
||||
user, err = utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Creator = user.UserName
|
||||
|
||||
models, total, err := dao.Model.GetByCreatorAndPlatform(ctx, req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return &dto.ListModelRes{
|
||||
List: models,
|
||||
Total: total,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetModelTypesFromConfig 从配置文件读取模型类型
|
||||
func GetModelTypesFromConfig() (res *dto.TypeItem, err error) {
|
||||
// 返回副本,避免外部修改
|
||||
types := make(map[int]string, len(public.ModelTypeName))
|
||||
for k, v := range public.ModelTypeName {
|
||||
types[k] = v
|
||||
}
|
||||
return &dto.TypeItem{
|
||||
Type: types,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetOperatorList 获取运营商列表
|
||||
func GetOperatorList() (res *dto.ListOperatorRes, err error) {
|
||||
return &dto.ListOperatorRes{
|
||||
List: public.OperatorList,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error {
|
||||
// 校验新会话模型是否存在
|
||||
newModel, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if newModel == nil {
|
||||
return errors.New("新会话模型不存在")
|
||||
}
|
||||
var user *beans.User
|
||||
user, err = utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 获取当前用户会话模型
|
||||
currentModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{
|
||||
Creator: user.UserName,
|
||||
},
|
||||
IsChatModel: new(1),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
||||
if !g.IsEmpty(currentModel) {
|
||||
if currentModel.ModelType != public.ModelTypeInference {
|
||||
return errors.New("当前模型为非推理模型,不能设置为会话模型")
|
||||
}
|
||||
|
||||
// 如果点击的就是当前会话模型(已经是1),取消它(设为0)
|
||||
if currentModel.Id != req.Id {
|
||||
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id},
|
||||
IsChatModel: gconv.PtrInt(0),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 设置当前为会话模型(设为1)
|
||||
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
|
||||
IsChatModel: gconv.PtrInt(1),
|
||||
})
|
||||
return err
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *modelService) GetIsChatModel(ctx context.Context) (*dto.GetIsChatModelRes, error) {
|
||||
user, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{
|
||||
Creator: user.UserName,
|
||||
},
|
||||
IsChatModel: new(1),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if model == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return &dto.GetIsChatModelRes{
|
||||
Model: model,
|
||||
}, nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package service
|
||||
package queue
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -1,4 +1,4 @@
|
||||
package service
|
||||
package queue
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -1,4 +1,4 @@
|
||||
package service
|
||||
package queue
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -11,9 +11,9 @@ import (
|
||||
// 上层每小时调用 /model/autoTune 写入运行时值;Worker/CreateTask 读取运行时值生效。
|
||||
|
||||
const (
|
||||
runtimeMaxCKeyPrefix = "asynch:runtime:max_concurrency:" // + model_name
|
||||
runtimeQueueKeyPrefix = "asynch:runtime:queue_limit:" // + model_name
|
||||
runtimeTTLSeconds = 2 * 3600 // 2小时,避免一次调参失败导致立即回退
|
||||
runtimeMaxCKeyPrefix = "asynch:runtime:max_concurrency:" // + model_name
|
||||
runtimeQueueKeyPrefix = "asynch:runtime:queue_limit:" // + model_name
|
||||
runtimeTTLSeconds = 2 * 3600 // 2小时,避免一次调参失败导致立即回退
|
||||
)
|
||||
|
||||
func runtimeMaxConcurrencyKey(modelName string) string {
|
||||
@@ -80,4 +80,3 @@ func clampInt(v, minV, maxV int) int {
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package service
|
||||
package queue
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -34,7 +34,8 @@ end
|
||||
return 1
|
||||
`
|
||||
|
||||
func acquireSemaphore(ctx context.Context, key string, max int, ttlSeconds int64) (bool, error) {
|
||||
// AcquireSemaphore 获取并发令牌
|
||||
func AcquireSemaphore(ctx context.Context, key string, max int, ttlSeconds int64) (bool, error) {
|
||||
if max <= 0 {
|
||||
// 不限制
|
||||
return true, nil
|
||||
@@ -49,8 +50,8 @@ func acquireSemaphore(ctx context.Context, key string, max int, ttlSeconds int64
|
||||
return gconv.Int(r) == 1, nil
|
||||
}
|
||||
|
||||
func releaseSemaphore(ctx context.Context, key string) error {
|
||||
// ReleaseSemaphore 释放并发令牌
|
||||
func ReleaseSemaphore(ctx context.Context, key string) error {
|
||||
_, err := g.Redis().Do(ctx, "EVAL", releaseLua, 1, key)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package service
|
||||
package stat
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -1,9 +1,10 @@
|
||||
package service
|
||||
package task
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"model-gateway/common/util"
|
||||
"model-gateway/service/queue"
|
||||
"time"
|
||||
|
||||
"model-gateway/dao"
|
||||
@@ -20,10 +21,11 @@ var Task = &taskService{}
|
||||
|
||||
type taskService struct{}
|
||||
|
||||
// Create 创建任务
|
||||
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
|
||||
startAt := time.Now()
|
||||
// 固化 token/user 等信息
|
||||
ctx = util.AsyncCtx(ctx)
|
||||
taskID := uuid.NewString()
|
||||
|
||||
// 1) 检查模型配置
|
||||
m, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
ModelName: req.ModelName,
|
||||
@@ -35,11 +37,10 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
|
||||
return nil, errors.New("模型不存在或未启用")
|
||||
}
|
||||
|
||||
taskID := uuid.NewString()
|
||||
// 2) 排队上限(严格控制:Redis 原子闸门)
|
||||
limit := GetRuntimeQueueLimit(ctx, req.ModelName, m.QueueLimit)
|
||||
limit := queue.GetRuntimeQueueLimit(ctx, req.ModelName, m.QueueLimit)
|
||||
if limit > 0 {
|
||||
ok, err := AcquireQueueSlot(ctx, req.ModelName, taskID, limit, m.ExpectedSeconds)
|
||||
ok, err := queue.AcquireQueueSlot(ctx, req.ModelName, taskID, limit, m.ExpectedSeconds)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -48,13 +49,12 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
|
||||
}
|
||||
}
|
||||
|
||||
// 将调用模型的 payload 与透传头信息一起存入 request_payload,供后台 worker 使用
|
||||
// 3) 插入任务记录
|
||||
storedPayload := map[string]any{
|
||||
"payload": req.RequestPayload,
|
||||
"headers": util.ForwardHeaders(ctx),
|
||||
}
|
||||
|
||||
t := &entity.AsynchTask{
|
||||
_, err = dao.Task.Insert(ctx, &entity.AsynchTask{
|
||||
ModelName: req.ModelName,
|
||||
TaskID: taskID,
|
||||
State: 0,
|
||||
@@ -64,21 +64,20 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
|
||||
InputRef: req.InputRef,
|
||||
RequestPayload: storedPayload,
|
||||
EpicycleId: req.EpicycleId,
|
||||
}
|
||||
_, err = dao.Task.Insert(ctx, t)
|
||||
})
|
||||
if err != nil {
|
||||
// 入库失败:回滚闸门占位
|
||||
ReleaseQueueSlot(ctx, req.ModelName, taskID)
|
||||
queue.ReleaseQueueSlot(ctx, req.ModelName, taskID)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 3) 写操作日志(尽量不影响主流程,失败忽略)
|
||||
// 4) 写操作日志(不影响主流程,失败忽略)
|
||||
ip := ""
|
||||
ua := ""
|
||||
apiPath := "/task/createTask"
|
||||
httpMethod := "POST"
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
ip = r.GetClientIp()
|
||||
ip = util.GetLocalIP()
|
||||
ua = r.UserAgent()
|
||||
apiPath = r.URL.Path
|
||||
httpMethod = r.Method
|
||||
@@ -101,70 +100,68 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
|
||||
},
|
||||
})
|
||||
|
||||
// 4) 创建成功后立即异步尝试执行当前任务,并仅在任务仍处于 pending(state=0) 时做定向轮询。
|
||||
// 5) 创建成功后立即异步尝试执行当前任务,并仅在任务仍处于 pending(state=0) 时做定向轮询。
|
||||
// 一旦任务进入 running/success/failed/downloaded,就停止轮询,避免一直空转。
|
||||
go s.pollAndRunUntilPicked(context.WithoutCancel(ctx), taskID, req)
|
||||
go s.pollAndRunUntilPicked(util.AsyncCtx(ctx), taskID, req)
|
||||
return &dto.CreateTaskRes{TaskID: taskID}, nil
|
||||
}
|
||||
|
||||
// pollAndRunUntilPicked 用于 createTask 创建后的“轻量级定向轮询”:
|
||||
// pollAndRunUntilPicked 定向轮询执行刚创建的任务
|
||||
// - 目标:尽快把刚创建的任务拉起来执行
|
||||
// - 只在任务仍为 pending(state=0) 时继续尝试抢占
|
||||
// - 一旦任务进入 running(1) / success(2) / failed(3) / downloaded(4),立即停止
|
||||
// - 这样不会无限轮询;runWork 仍负责处理积压队列和未处理到的任务
|
||||
// - 不会无限轮询;runWork 仍负责处理积压队列和未处理到的任务
|
||||
func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string, req *dto.CreateTaskReq) {
|
||||
if taskID == "" {
|
||||
return
|
||||
}
|
||||
interval := g.Cfg().MustGet(ctx, "asynch.worker.intervalSeconds").Int()
|
||||
if interval <= 0 {
|
||||
interval = 5
|
||||
}
|
||||
g.Log().Infof(ctx, "[task-auto-run][start] taskId=%s interval=%ds", taskID, interval)
|
||||
interval := g.Cfg().MustGet(ctx, "asynch.worker.intervalSeconds", 5).Int()
|
||||
pollTimeout := g.Cfg().MustGet(ctx, "asynch.worker.pollTimeoutSeconds", 300).Int()
|
||||
pollCtx, cancel := context.WithTimeout(ctx, time.Duration(pollTimeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
ticker := time.NewTicker(time.Duration(interval) * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
g.Log().Infof(ctx, "[任务自动执行][开始] taskId=%s 轮询间隔=%ds 超时=%ds", taskID, interval, pollTimeout)
|
||||
tryRun := func() bool {
|
||||
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
|
||||
TaskID: taskID,
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[task-auto-run][stop] taskId=%s reason=query_failed err=%v", taskID, err)
|
||||
g.Log().Warningf(ctx, "[任务自动执行][停止] taskId=%s 原因=查询失败 err=%v", taskID, err)
|
||||
return true
|
||||
}
|
||||
if t == nil {
|
||||
g.Log().Warningf(ctx, "[task-auto-run][stop] taskId=%s reason=task_not_found", taskID)
|
||||
g.Log().Warningf(ctx, "[任务自动执行][停止] taskId=%s 原因=任务不存在", taskID)
|
||||
return true
|
||||
}
|
||||
|
||||
switch t.State {
|
||||
case 0:
|
||||
//RunByTaskID 尝试执行任务
|
||||
if err = AsyncWorker.RunByTaskID(ctx, taskID, req); err != nil {
|
||||
g.Log().Warningf(ctx, "[task-auto-run][retry] taskId=%s state=0 err=%v", taskID, err)
|
||||
g.Log().Warningf(ctx, "[任务自动执行][重试] taskId=%s 状态=待处理 err=%v", taskID, err)
|
||||
} else {
|
||||
g.Log().Infof(ctx, "[task-auto-run][triggered] taskId=%s state=0", taskID)
|
||||
g.Log().Infof(ctx, "[任务自动执行][已触发] taskId=%s 状态=待处理", taskID)
|
||||
}
|
||||
return false
|
||||
case 1:
|
||||
g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=running", taskID)
|
||||
g.Log().Infof(ctx, "[任务自动执行][停止] taskId=%s 原因=执行中", taskID)
|
||||
return true
|
||||
case 2, 3, 4:
|
||||
g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=terminal state=%d", taskID, t.State)
|
||||
g.Log().Infof(ctx, "[任务自动执行][停止] taskId=%s 原因=终态 状态=%d", taskID, t.State)
|
||||
return true
|
||||
default:
|
||||
g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=unknown_state state=%d", taskID, t.State)
|
||||
g.Log().Infof(ctx, "[任务自动执行][停止] taskId=%s 原因=未知状态 状态=%d", taskID, t.State)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// 先立即尝试一次
|
||||
// 立即尝试一次
|
||||
if stop := tryRun(); stop {
|
||||
return
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=context_done", taskID)
|
||||
case <-pollCtx.Done():
|
||||
g.Log().Infof(ctx, "[任务自动执行][停止] taskId=%s 原因=轮询超时", taskID)
|
||||
return
|
||||
case <-ticker.C:
|
||||
if stop := tryRun(); stop {
|
||||
@@ -174,6 +171,7 @@ func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string,
|
||||
}
|
||||
}
|
||||
|
||||
// GetResult 获取任务结果
|
||||
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
|
||||
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
|
||||
TaskID: taskID,
|
||||
@@ -244,6 +242,7 @@ func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (r
|
||||
return &dto.GetTaskBatchRes{List: items}, nil
|
||||
}
|
||||
|
||||
// List 获取任务列表
|
||||
func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
|
||||
pageNum, pageSize := 1, 10
|
||||
if req != nil {
|
||||
@@ -1,12 +1,17 @@
|
||||
package service
|
||||
package task
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"model-gateway/common/util"
|
||||
"model-gateway/model/dto"
|
||||
"model-gateway/service/gateway"
|
||||
"model-gateway/service/queue"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -56,7 +61,7 @@ func (w *asyncWorker) RunOnce(ctx context.Context, req *dto.RunWorkReq) (res *dt
|
||||
if e != nil {
|
||||
task.ErrorMsg = fmt.Sprintf("worker panic: %v", e)
|
||||
_ = dao.Task.UpdateFailedGlobal(ctx, task)
|
||||
ReleaseQueueSlot(ctx, task.ModelName, task.TaskID)
|
||||
queue.ReleaseQueueSlot(ctx, task.ModelName, task.TaskID)
|
||||
}
|
||||
done <- struct{}{}
|
||||
})
|
||||
@@ -100,8 +105,8 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *
|
||||
|
||||
// 2) 分布式并发控制
|
||||
semKey := fmt.Sprintf("asynch:sem:%s", t.ModelName)
|
||||
maxC := GetRuntimeMaxConcurrency(ctx, t.ModelName, model.MaxConcurrency)
|
||||
acquired, err := acquireSemaphore(ctx, semKey, maxC, 3600)
|
||||
maxC := queue.GetRuntimeMaxConcurrency(ctx, t.ModelName, model.MaxConcurrency)
|
||||
acquired, err := queue.AcquireSemaphore(ctx, semKey, maxC, 3600)
|
||||
if err != nil {
|
||||
w.failTask(ctx, t, err.Error())
|
||||
return
|
||||
@@ -111,7 +116,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *
|
||||
_ = w.rollbackToPending(ctx, t.Id)
|
||||
return
|
||||
}
|
||||
defer func() { _ = releaseSemaphore(ctx, semKey) }()
|
||||
defer func() { _ = queue.ReleaseSemaphore(ctx, semKey) }()
|
||||
|
||||
// 3) request_payload 校验
|
||||
if payload == nil {
|
||||
@@ -146,31 +151,32 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *
|
||||
}
|
||||
|
||||
// 6) 解析校验(可重试,失败重新调模型)
|
||||
if req.BuildType == 1 {
|
||||
for attempt := 0; attempt <= maxRetry; attempt++ {
|
||||
if attempt > 0 {
|
||||
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID)
|
||||
}
|
||||
err = util.ValidatePromptResult(textResult, model.RequestMapping)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v",
|
||||
t.TaskID, attempt, maxRetry, err)
|
||||
if attempt == maxRetry {
|
||||
w.failTask(ctx, t, fmt.Sprintf("JSON解析重试耗尽: %v", err))
|
||||
return
|
||||
}
|
||||
// 重新调模型
|
||||
newResult, modelErr := w.callModel(ctx, t, model, payload)
|
||||
if modelErr != nil {
|
||||
g.Log().Warningf(ctx, "[执行任务][重试] 重新调模型失败 taskId=%s attempt=%d/%d err=%v",
|
||||
t.TaskID, attempt, maxRetry, modelErr)
|
||||
continue
|
||||
}
|
||||
textResult = newResult
|
||||
}
|
||||
}
|
||||
//if req.BuildType == 1 {
|
||||
// for attempt := 0; attempt <= maxRetry; attempt++ {
|
||||
// if attempt > 0 {
|
||||
// g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID)
|
||||
// }
|
||||
// // 6.1) 校验数据
|
||||
// err = util.ValidatePromptResult(textResult, model)
|
||||
// if err == nil {
|
||||
// break
|
||||
// }
|
||||
// g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v",
|
||||
// t.TaskID, attempt, maxRetry, err)
|
||||
// if attempt == maxRetry {
|
||||
// w.failTask(ctx, t, fmt.Sprintf("JSON解析重试耗尽: %v", err))
|
||||
// return
|
||||
// }
|
||||
// // 6.2) 重新调模型
|
||||
// newResult, modelErr := w.callModel(ctx, t, model, payload)
|
||||
// if modelErr != nil {
|
||||
// g.Log().Warningf(ctx, "[执行任务][重试] 重新调模型失败 taskId=%s attempt=%d/%d err=%v",
|
||||
// t.TaskID, attempt, maxRetry, modelErr)
|
||||
// continue
|
||||
// }
|
||||
// textResult = newResult
|
||||
// }
|
||||
//}
|
||||
|
||||
// 7) 成功回调
|
||||
t.State = 2
|
||||
@@ -185,7 +191,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *
|
||||
return
|
||||
}
|
||||
|
||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||||
if req.EpicycleId != 0 {
|
||||
go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), t, req.EpicycleId)
|
||||
@@ -198,29 +204,29 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *
|
||||
|
||||
// 返回: ossURL(成功时有值), fileType, textResult(失败时是错误信息), retryable(是否可重试)
|
||||
// callModel 调用模型 + 检测文件类型 + 保存临时文件
|
||||
func (w *asyncWorker) callModel(ctx context.Context, t *entity.AsynchTask, m *entity.AsynchModel, payload map[string]any) (map[string]any, error) {
|
||||
func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, payload map[string]any) (map[string]any, error) {
|
||||
var data []byte
|
||||
var contentType, ext, textResult string
|
||||
var err error
|
||||
|
||||
if t.Phase == 1 && strings.TrimSpace(t.TmpFile) != "" {
|
||||
data, err = os.ReadFile(t.TmpFile)
|
||||
if task.Phase == 1 && strings.TrimSpace(task.TmpFile) != "" {
|
||||
data, err = os.ReadFile(task.TmpFile)
|
||||
if err != nil || len(data) == 0 {
|
||||
data = nil
|
||||
}
|
||||
}
|
||||
|
||||
if data == nil {
|
||||
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(t.TenantId), t.Creator, t.ModelName)
|
||||
data, err = InvokeModel(ctx, m, payload, t.ModelKey)
|
||||
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName)
|
||||
data, err = InvokeModel(ctx, model, payload, task.ModelKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tmpPath, tmpErr := saveTmpResult(t.TaskID, data, ext)
|
||||
tmpPath, tmpErr := saveTmpResult(task.TaskID, data, ext)
|
||||
if tmpErr == nil && tmpPath != "" {
|
||||
t.TmpFile = tmpPath
|
||||
t.Phase = 1
|
||||
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, t.Id, tmpPath)
|
||||
task.TmpFile = tmpPath
|
||||
task.Phase = 1
|
||||
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -228,10 +234,138 @@ func (w *asyncWorker) callModel(ctx context.Context, t *entity.AsynchTask, m *en
|
||||
if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") {
|
||||
textResult = string(data)
|
||||
}
|
||||
|
||||
return gjson.New(textResult).Map(), nil
|
||||
}
|
||||
|
||||
// InvokeModel 调用模型服务,返回二进制结果
|
||||
// modelKey 用于覆盖/补充模型配置 head_msg(例如每次请求携带不同的 X-API-Key)
|
||||
func InvokeModel(ctx context.Context, model *entity.AsynchModel, payload map[string]any, modelKey string) ([]byte, error) {
|
||||
// 1)请求参数映射:将标准 payload 按模型配置的 requestMapping 转为模型需要的格式
|
||||
//mappedPayload := util.ReverseMap(model.RequestMapping, payload)
|
||||
|
||||
// 2)构建请求 URL 和超时
|
||||
baseURL := strings.TrimRight(model.BaseURL, "/")
|
||||
timeout := time.Duration(model.TimeoutSeconds) * time.Second
|
||||
client := &http.Client{Timeout: timeout}
|
||||
method := strings.ToUpper(strings.TrimSpace(model.HttpMethod))
|
||||
|
||||
// 3)构建 HTTP 请求
|
||||
var req *http.Request
|
||||
switch method {
|
||||
case http.MethodGet:
|
||||
q, err := util.PayloadToQuery(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(q) > 0 {
|
||||
if strings.Contains(baseURL, "?") {
|
||||
baseURL = baseURL + "&" + q.Encode()
|
||||
} else {
|
||||
baseURL = baseURL + "?" + q.Encode()
|
||||
}
|
||||
}
|
||||
req, err = http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil)
|
||||
default:
|
||||
bodyBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes))
|
||||
}
|
||||
|
||||
// 4)注入请求头:先模型静态配置,再动态 modelKey(后者可覆盖前者)
|
||||
for hk, hv := range util.ParseHeadMsgHeaders(model.HeadMsg) {
|
||||
req.Header.Set(hk, hv)
|
||||
}
|
||||
for hk, hv := range util.ParseHeadMsgHeaders(modelKey) {
|
||||
req.Header.Set(hk, hv)
|
||||
}
|
||||
if method != http.MethodGet {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
// 5)发送请求
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 6)读取响应体
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 7)检查 HTTP 状态码
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
msg := string(b)
|
||||
return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
|
||||
}
|
||||
|
||||
// 8)响应参数映射
|
||||
mappedResponse, err := util.MapResponsePayload(model.ResponseMapping, b)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "响应参数映射失败: %v,返回原始数据", err)
|
||||
return b, nil
|
||||
}
|
||||
return mappedResponse, nil
|
||||
}
|
||||
|
||||
// // InvokeModel 调用模型服务,返回二进制结果
|
||||
//
|
||||
// func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelKey string) ([]byte, error) {
|
||||
// if m == nil || m.BaseURL == "" {
|
||||
// return nil, fmt.Errorf("模型配置不完整")
|
||||
// }
|
||||
// // 请求参数映射
|
||||
// mappedPayload, err := mapRequestPayload(m.RequestMapping, payload)
|
||||
// if err != nil {
|
||||
// return nil, fmt.Errorf("请求参数映射失败: %w", err)
|
||||
// }
|
||||
// // 合并请求头
|
||||
// headers := util.ForwardHeaders(ctx)
|
||||
// for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) {
|
||||
// headers[hk] = hv
|
||||
// }
|
||||
// for hk, hv := range parseHeadMsgHeaders(modelKey) {
|
||||
// headers[hk] = hv
|
||||
// }
|
||||
//
|
||||
// // 设置超时
|
||||
// timeout := time.Duration(m.TimeoutSeconds) * time.Second
|
||||
// if timeout <= 0 {
|
||||
// timeout = 600 * time.Second
|
||||
// }
|
||||
// ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
// defer cancel()
|
||||
//
|
||||
// invokeUrl := strings.TrimRight(m.BaseURL, "/")
|
||||
// method := strings.ToUpper(strings.TrimSpace(m.HttpMethod))
|
||||
// if method == "" {
|
||||
// method = http.MethodPost
|
||||
// }
|
||||
//
|
||||
// var respBytes []byte
|
||||
//
|
||||
// switch method {
|
||||
// case http.MethodGet:
|
||||
// err = commonHttp.Get(ctx, invokeUrl, headers, &respBytes, mappedPayload)
|
||||
// default:
|
||||
// err = commonHttp.Post(ctx, invokeUrl, headers, &respBytes, mappedPayload)
|
||||
// }
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// // 响应参数映射
|
||||
// mappedResponse, err := mapResponsePayload(m.ResponseMapping, respBytes)
|
||||
// if err != nil {
|
||||
// g.Log().Warningf(ctx, "响应参数映射失败: %v,返回原始数据", err)
|
||||
// return respBytes, nil
|
||||
// }
|
||||
// return mappedResponse, nil
|
||||
// }
|
||||
|
||||
// uploadOSS 从临时文件上传 OSS
|
||||
func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.AsynchTask) (*gateway.UploadFileResponse, error) {
|
||||
data, err := os.ReadFile(t.TmpFile)
|
||||
@@ -247,7 +381,7 @@ func (w *asyncWorker) failTask(ctx context.Context, t *entity.AsynchTask, errMsg
|
||||
t.State = 3
|
||||
t.ErrorMsg = errMsg
|
||||
_ = dao.Task.UpdateFailedGlobal(ctx, t)
|
||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user