refactor(service): 重构服务模块结构并优化模型配置
This commit is contained in:
10
common/util/convert.go
Normal file
10
common/util/convert.go
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import "github.com/gogf/gf/v2/util/gconv"
|
||||||
|
|
||||||
|
// ConvertTo 转换为指定类型
|
||||||
|
func ConvertTo[T any](v interface{}) *T {
|
||||||
|
var t T
|
||||||
|
_ = gconv.Struct(v, &t)
|
||||||
|
return &t
|
||||||
|
}
|
||||||
@@ -1,69 +0,0 @@
|
|||||||
package util
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ValidatePromptResult 完整的校验逻辑
|
|
||||||
func ValidatePromptResult(raw map[string]any, requestMapping map[string]any) error {
|
|
||||||
contentStr, ok := raw["content"].(string)
|
|
||||||
if !ok || contentStr == "" {
|
|
||||||
return fmt.Errorf("content 字段为空或不是字符串")
|
|
||||||
}
|
|
||||||
|
|
||||||
var rounds []map[string]any
|
|
||||||
if err := json.Unmarshal([]byte(contentStr), &rounds); err != nil {
|
|
||||||
return fmt.Errorf("解析 content JSON 数组失败: %w", err)
|
|
||||||
}
|
|
||||||
if len(rounds) == 0 {
|
|
||||||
return fmt.Errorf("content 数组为空")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 对 rounds 中的每一个元素进行结构校验
|
|
||||||
for i, round := range rounds {
|
|
||||||
if err := validateStructure(requestMapping, round); err != nil {
|
|
||||||
return fmt.Errorf("rounds[%d] 结构校验失败: %w", i, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateStructure 递归校验 actual 是否包含 expected 定义的所有字段路径
|
|
||||||
func validateStructure(expected any, actual any) error {
|
|
||||||
switch exp := expected.(type) {
|
|
||||||
case map[string]any:
|
|
||||||
act, ok := actual.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("期望对象,实际类型 %T", actual)
|
|
||||||
}
|
|
||||||
for key, expVal := range exp {
|
|
||||||
actVal, exists := act[key]
|
|
||||||
if !exists {
|
|
||||||
return fmt.Errorf("缺少字段: %s", key)
|
|
||||||
}
|
|
||||||
if err := validateStructure(expVal, actVal); err != nil {
|
|
||||||
return fmt.Errorf("%s: %w", key, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
case []any:
|
|
||||||
act, ok := actual.([]any)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("期望数组,实际类型 %T", actual)
|
|
||||||
}
|
|
||||||
if len(exp) == 0 {
|
|
||||||
return nil // 空数组模板,只校验类型
|
|
||||||
}
|
|
||||||
// 用第一个元素的结构去校验每个实际元素
|
|
||||||
for i, actItem := range act {
|
|
||||||
if err := validateStructure(exp[0], actItem); err != nil {
|
|
||||||
return fmt.Errorf("[%d]: %w", i, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
// 基本类型,不校验具体值,只检查存在
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
151
common/util/mapping.go
Normal file
151
common/util/mapping.go
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"model-gateway/model/entity"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gogf/gf/v2/encoding/gjson"
|
||||||
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ValidatePromptResult 校验模型返回结果的 JSON 结构完整性
|
||||||
|
// 校验逻辑:只校验 requestMapping 中默认值为空的必填字段
|
||||||
|
func ValidatePromptResult(raw map[string]any, model *entity.AsynchModel) error {
|
||||||
|
// 1) 获取校验配置,并取值
|
||||||
|
requestMapping := model.RequestMapping
|
||||||
|
contentKey := ""
|
||||||
|
for k := range model.ResponseBody {
|
||||||
|
contentKey = k
|
||||||
|
break
|
||||||
|
}
|
||||||
|
contentStr, ok := raw[contentKey].(string)
|
||||||
|
if !ok || contentStr == "" {
|
||||||
|
return fmt.Errorf("%s 字段为空或不是字符串", contentKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2) 解析 content 为 JSON 数组
|
||||||
|
var rounds []map[string]any
|
||||||
|
if err := gjson.DecodeTo(contentStr, &rounds); err != nil {
|
||||||
|
return fmt.Errorf("解析 content JSON 数组失败: %w", err)
|
||||||
|
}
|
||||||
|
if len(rounds) == 0 {
|
||||||
|
return fmt.Errorf("content 数组为空")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3) 逐条校验:只检查默认值为空的必填字段是否存在
|
||||||
|
for i, round := range rounds {
|
||||||
|
for path, defaultValue := range requestMapping {
|
||||||
|
if !g.IsEmpty(defaultValue) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if gjson.New(round).Get(path).IsNil() {
|
||||||
|
return fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReverseMap 映射 payload 到 mapping
|
||||||
|
func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
|
||||||
|
jsonObj := gjson.New("{}")
|
||||||
|
for path, defaultValue := range mapping {
|
||||||
|
// 从 payload 取对应路径的值
|
||||||
|
val := gjson.New(payload).Get(path)
|
||||||
|
if !val.IsNil() {
|
||||||
|
// payload 有值,用它
|
||||||
|
_ = jsonObj.Set(path, val.Val())
|
||||||
|
} else if !g.IsEmpty(defaultValue) {
|
||||||
|
// payload 没值,用默认值
|
||||||
|
_ = jsonObj.Set(path, defaultValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return jsonObj.Map()
|
||||||
|
}
|
||||||
|
|
||||||
|
// MapResponsePayload 映射模型响应为标准格式
|
||||||
|
func MapResponsePayload(mapping map[string]any, responseBytes []byte) ([]byte, error) {
|
||||||
|
if len(mapping) == 0 {
|
||||||
|
return responseBytes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
responseJson := gjson.New(responseBytes)
|
||||||
|
resultJson := gjson.New("{}")
|
||||||
|
|
||||||
|
for standardField, modelPath := range mapping {
|
||||||
|
path := gconv.String(modelPath)
|
||||||
|
if path == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
val := responseJson.Get(path)
|
||||||
|
if val.IsNil() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
resultJson.Set(standardField, val.Val())
|
||||||
|
}
|
||||||
|
|
||||||
|
return []byte(resultJson.String()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseHeadMsgHeaders 支持多个 header 绑定,逗号分隔:
|
||||||
|
// 示例:
|
||||||
|
// - X-API-Key:qwen3-tts-key,operation:true,count:123
|
||||||
|
// - X-API-Key:"qwen3-tts-key",operation:"true"
|
||||||
|
//
|
||||||
|
// 说明:
|
||||||
|
// - HTTP Header 最终都是字符串,这里做的是“值的字符串化表达”。
|
||||||
|
// - 若 value 用双引号包裹,会去掉外层引号再注入,便于在配置中区分字符串/布尔/数字等表达(以及避免值中包含特殊字符时歧义)。
|
||||||
|
func ParseHeadMsgHeaders(headMsg string) map[string]string {
|
||||||
|
headMsg = strings.TrimSpace(headMsg)
|
||||||
|
if headMsg == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := map[string]string{}
|
||||||
|
parts := strings.Split(headMsg, ",")
|
||||||
|
for _, p := range parts {
|
||||||
|
p = strings.TrimSpace(p)
|
||||||
|
if p == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// HeaderName:HeaderValue(推荐) / HeaderName=HeaderValue(兼容)
|
||||||
|
if strings.Contains(p, ":") {
|
||||||
|
kv := strings.SplitN(p, ":", 2)
|
||||||
|
k := strings.TrimSpace(kv[0])
|
||||||
|
v := strings.TrimSpace(kv[1])
|
||||||
|
v = strings.Trim(v, "\"")
|
||||||
|
if k != "" && v != "" {
|
||||||
|
out[k] = v
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.Contains(p, "=") {
|
||||||
|
kv := strings.SplitN(p, "=", 2)
|
||||||
|
k := strings.TrimSpace(kv[0])
|
||||||
|
v := strings.TrimSpace(kv[1])
|
||||||
|
v = strings.Trim(v, "\"")
|
||||||
|
if k != "" && v != "" {
|
||||||
|
out[k] = v
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(out) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// PayloadToQuery 将 payload 转为 url.Values
|
||||||
|
func PayloadToQuery(payload map[string]any) (url.Values, error) {
|
||||||
|
q := url.Values{}
|
||||||
|
for k, v := range payload {
|
||||||
|
if v == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
q.Set(k, gconv.String(v))
|
||||||
|
}
|
||||||
|
return q, nil
|
||||||
|
}
|
||||||
140
common/util/network.go
Normal file
140
common/util/network.go
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetLocalIP 获取本机有效的局域网 IPv4 地址
|
||||||
|
func GetLocalIP() string {
|
||||||
|
addrs, err := net.InterfaceAddrs()
|
||||||
|
if err != nil {
|
||||||
|
return "127.0.0.1"
|
||||||
|
}
|
||||||
|
|
||||||
|
var validIPs []string
|
||||||
|
|
||||||
|
for _, addr := range addrs {
|
||||||
|
ipnet, ok := addr.(*net.IPNet)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
ip := ipnet.IP
|
||||||
|
|
||||||
|
if isIPValid(ip) {
|
||||||
|
validIPs = append(validIPs, ip.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 优先返回非 169.254.x.x 的 IP
|
||||||
|
for _, ip := range validIPs {
|
||||||
|
if !strings.HasPrefix(ip, "169.254.") {
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 其次返回 169.254.x.x(最后的选择)
|
||||||
|
if len(validIPs) > 0 {
|
||||||
|
return validIPs[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
return "127.0.0.1"
|
||||||
|
}
|
||||||
|
|
||||||
|
// isIPValid 判断 IP 是否有效
|
||||||
|
func isIPValid(ip net.IP) bool {
|
||||||
|
// 不是 loopback (127.0.0.1)
|
||||||
|
if ip.IsLoopback() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 是 IPv4
|
||||||
|
if ip.To4() == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 不是链路本地地址 (169.254.0.0/16)
|
||||||
|
if ip[0] == 169 && ip[1] == 254 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 不是组播地址
|
||||||
|
if ip.IsMulticast() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 不是未指定地址 (0.0.0.0)
|
||||||
|
if ip.IsUnspecified() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLocalAddress 获取局域网地址(IP:端口)
|
||||||
|
func GetLocalAddress(ctx context.Context) string {
|
||||||
|
ip := GetLocalIP()
|
||||||
|
port := GetServerPort(ctx)
|
||||||
|
|
||||||
|
if port == "80" || port == "443" {
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
return ip + ":" + port
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSchemaFromRequest 从当前请求中获取协议(http/https)
|
||||||
|
func GetSchemaFromRequest(ctx context.Context) string {
|
||||||
|
r := g.RequestFromCtx(ctx)
|
||||||
|
if r == nil {
|
||||||
|
return "http"
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1. 代理场景:X-Forwarded-Proto
|
||||||
|
if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" {
|
||||||
|
return proto
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 代理场景:X-Forwarded-Scheme
|
||||||
|
if proto := r.Header.Get("X-Forwarded-Scheme"); proto != "" {
|
||||||
|
return proto
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. TLS 连接(直接 HTTPS)
|
||||||
|
if r.TLS != nil {
|
||||||
|
return "https"
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. 默认 HTTP(这行很重要!)
|
||||||
|
return "http" // ← 确保有这行
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLocalBaseURL 获取局域网基础 URL(动态协议 + IP + 端口)
|
||||||
|
func GetLocalBaseURL(ctx context.Context) string {
|
||||||
|
schema := GetSchemaFromRequest(ctx)
|
||||||
|
addr := GetLocalAddress(ctx)
|
||||||
|
return schema + "://" + addr
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCallbackURL 获取回调地址(完整 URL)
|
||||||
|
func GetCallbackURL(ctx context.Context, path string) string {
|
||||||
|
baseURL := GetLocalBaseURL(ctx)
|
||||||
|
// 确保 path 以 / 开头
|
||||||
|
if !strings.HasPrefix(path, "/") {
|
||||||
|
path = "/" + path
|
||||||
|
}
|
||||||
|
return baseURL + path
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetServerPort 从配置获取服务端口
|
||||||
|
func GetServerPort(ctx context.Context) string {
|
||||||
|
address := g.Cfg().MustGet(ctx, "server.address", ":8080").String()
|
||||||
|
// address 格式如 ":3009",去掉冒号
|
||||||
|
if strings.HasPrefix(address, ":") {
|
||||||
|
return address[1:]
|
||||||
|
}
|
||||||
|
return "8080"
|
||||||
|
}
|
||||||
@@ -4,11 +4,12 @@ package public
|
|||||||
const (
|
const (
|
||||||
ModelTypeInference = 100 // 推理模型
|
ModelTypeInference = 100 // 推理模型
|
||||||
|
|
||||||
ModelTypeImage = 200 // 图片模型
|
ModelTypeImage = 200 // 图片模型
|
||||||
ImageSubTypeTextToImage = 201 // 图片模型-文生图
|
ImageSubTypeTextToImage = 201 // 图片模型-文生图
|
||||||
ImageSubTypeImageToImage = 202 // 图片模型-图生图
|
ImageSubTypeImageToImage = 202 // 图片模型-图生图
|
||||||
ImageSubTypeImageEdit = 203 // 图片模型-图片编辑
|
ImageSubTypeImageEdit = 203 // 图片模型-图片编辑
|
||||||
ImageSubTypeImageVariation = 204 // 图片模型-图片变体
|
ImageSubTypeImageVariation = 204 // 图片模型-图片变体
|
||||||
|
ImageSubTypeImageTextToImage = 205 // 图片模型-图文生图
|
||||||
|
|
||||||
ModelTypeAudio = 300 // 音频模型
|
ModelTypeAudio = 300 // 音频模型
|
||||||
AudioSubTypeTextToSpeech = 301 // 音频模型-文生音
|
AudioSubTypeTextToSpeech = 301 // 音频模型-文生音
|
||||||
@@ -35,11 +36,12 @@ const (
|
|||||||
var ModelTypeName = map[int]string{
|
var ModelTypeName = map[int]string{
|
||||||
ModelTypeInference: "推理模型",
|
ModelTypeInference: "推理模型",
|
||||||
|
|
||||||
ModelTypeImage: "图片模型",
|
ModelTypeImage: "图片模型",
|
||||||
ImageSubTypeTextToImage: "图片模型-文生图",
|
ImageSubTypeTextToImage: "图片模型-文生图",
|
||||||
ImageSubTypeImageToImage: "图片模型-图生图",
|
ImageSubTypeImageToImage: "图片模型-图生图",
|
||||||
ImageSubTypeImageEdit: "图片模型-图片编辑",
|
ImageSubTypeImageEdit: "图片模型-图片编辑",
|
||||||
ImageSubTypeImageVariation: "图片模型-图片变体",
|
ImageSubTypeImageVariation: "图片模型-图片变体",
|
||||||
|
ImageSubTypeImageTextToImage: "图片模型-图文生图",
|
||||||
|
|
||||||
ModelTypeAudio: "音频模型",
|
ModelTypeAudio: "音频模型",
|
||||||
AudioSubTypeTextToSpeech: "音频模型-文生音",
|
AudioSubTypeTextToSpeech: "音频模型-文生音",
|
||||||
|
|||||||
@@ -2,9 +2,9 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"model-gateway/model/dto"
|
"model-gateway/model/dto"
|
||||||
"model-gateway/service"
|
modelService "model-gateway/service/model"
|
||||||
|
"model-gateway/service/queue"
|
||||||
)
|
)
|
||||||
|
|
||||||
type model struct{}
|
type model struct{}
|
||||||
@@ -14,53 +14,53 @@ var Model = new(model)
|
|||||||
|
|
||||||
// CreateModel 添加配置
|
// CreateModel 添加配置
|
||||||
func (c *model) CreateModel(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
|
func (c *model) CreateModel(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
|
||||||
return service.Model.Create(ctx, req)
|
return modelService.Model.Create(ctx, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateModel 更改配置
|
// UpdateModel 更改配置
|
||||||
func (c *model) UpdateModel(ctx context.Context, req *dto.UpdateModelReq) (res *dto.UpdateModelRes, err error) {
|
func (c *model) UpdateModel(ctx context.Context, req *dto.UpdateModelReq) (res *dto.UpdateModelRes, err error) {
|
||||||
err = service.Model.Update(ctx, req)
|
err = modelService.Model.Update(ctx, req)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteModel 删除配置
|
// DeleteModel 删除配置
|
||||||
func (c *model) DeleteModel(ctx context.Context, req *dto.DeleteModelReq) (res *dto.DeleteModelRes, err error) {
|
func (c *model) DeleteModel(ctx context.Context, req *dto.DeleteModelReq) (res *dto.DeleteModelRes, err error) {
|
||||||
err = service.Model.Delete(ctx, req)
|
err = modelService.Model.Delete(ctx, req)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetModel 获取配置详情
|
// GetModel 获取配置详情
|
||||||
func (c *model) GetModel(ctx context.Context, req *dto.GetModelReq) (res *dto.GetModelRes, err error) {
|
func (c *model) GetModel(ctx context.Context, req *dto.GetModelReq) (res *dto.GetModelRes, err error) {
|
||||||
return service.Model.Get(ctx, req)
|
return modelService.Model.Get(ctx, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListModel 配置列表
|
// ListModel 配置列表
|
||||||
func (c *model) ListModel(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) {
|
func (c *model) ListModel(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) {
|
||||||
return service.Model.List(ctx, req)
|
return modelService.Model.List(ctx, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AutoTune 动态调参(由上层定时任务每小时触发一次)
|
// AutoTune 动态调参(由上层定时任务每小时触发一次)
|
||||||
func (c *model) AutoTune(ctx context.Context, req *dto.AutoTuneReq) (res *dto.AutoTuneRes, err error) {
|
func (c *model) AutoTune(ctx context.Context, req *dto.AutoTuneReq) (res *dto.AutoTuneRes, err error) {
|
||||||
return service.AutoTune(ctx, req)
|
return queue.AutoTune(ctx, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListType 模型类型列表
|
// ListType 模型类型列表
|
||||||
func (c *model) ListType(ctx context.Context, req *dto.ListTypeReq) (res *dto.TypeItem, err error) {
|
func (c *model) ListType(ctx context.Context, req *dto.ListTypeReq) (res *dto.TypeItem, err error) {
|
||||||
return service.GetModelTypesFromConfig()
|
return modelService.GetModelTypesFromConfig()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListOperator 运营商列表
|
// ListOperator 运营商列表
|
||||||
func (c *model) ListOperator(ctx context.Context, req *dto.ListOperatorReq) (res *dto.ListOperatorRes, err error) {
|
func (c *model) ListOperator(ctx context.Context, req *dto.ListOperatorReq) (res *dto.ListOperatorRes, err error) {
|
||||||
return service.GetOperatorList()
|
return modelService.GetOperatorList()
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateChatModel 更新是否为聊天模型
|
// UpdateChatModel 更新是否为聊天模型
|
||||||
func (c *model) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) (res *dto.UpdateChatModelRes, err error) {
|
func (c *model) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) (res *dto.UpdateChatModelRes, err error) {
|
||||||
err = service.Model.UpdateChatModel(ctx, req)
|
err = modelService.Model.UpdateChatModel(ctx, req)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetIsChatModel 获取当前会话模型
|
// GetIsChatModel 获取当前会话模型
|
||||||
func (c *model) GetIsChatModel(ctx context.Context, req *dto.GetIsChatModelReq) (res *dto.GetIsChatModelRes, err error) {
|
func (c *model) GetIsChatModel(ctx context.Context, req *dto.GetIsChatModelReq) (res *dto.GetIsChatModelRes, err error) {
|
||||||
return service.Model.GetIsChatModel(ctx)
|
return modelService.Model.GetIsChatModel(ctx)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,9 +2,9 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
statService "model-gateway/service/stat"
|
||||||
|
|
||||||
"model-gateway/model/dto"
|
"model-gateway/model/dto"
|
||||||
"model-gateway/service"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type stat struct{}
|
type stat struct{}
|
||||||
@@ -14,5 +14,5 @@ var Stat = new(stat)
|
|||||||
|
|
||||||
// ListModelStat 统计列表
|
// ListModelStat 统计列表
|
||||||
func (c *stat) ListModelStat(ctx context.Context, req *dto.ListModelStatReq) (res *dto.ListModelStatRes, err error) {
|
func (c *stat) ListModelStat(ctx context.Context, req *dto.ListModelStatReq) (res *dto.ListModelStatRes, err error) {
|
||||||
return service.Stat.List(ctx, req)
|
return statService.Stat.List(ctx, req)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,9 +2,10 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"model-gateway/service/job"
|
||||||
|
taskService "model-gateway/service/task"
|
||||||
|
|
||||||
"model-gateway/model/dto"
|
"model-gateway/model/dto"
|
||||||
"model-gateway/service"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type task struct{}
|
type task struct{}
|
||||||
@@ -14,30 +15,30 @@ var Task = new(task)
|
|||||||
|
|
||||||
// CreateTask 根据 modelName 创建异步任务,返回 taskId
|
// CreateTask 根据 modelName 创建异步任务,返回 taskId
|
||||||
func (c *task) CreateTask(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
|
func (c *task) CreateTask(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
|
||||||
return service.Task.Create(ctx, req)
|
return taskService.Task.Create(ctx, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetTaskResult 获取任务结果(只返回 oss 地址 + state)
|
// GetTaskResult 获取任务结果(只返回 oss 地址 + state)
|
||||||
func (c *task) GetTaskResult(ctx context.Context, req *dto.GetTaskResultReq) (res *dto.GetTaskResultRes, err error) {
|
func (c *task) GetTaskResult(ctx context.Context, req *dto.GetTaskResultReq) (res *dto.GetTaskResultRes, err error) {
|
||||||
return service.Task.GetResult(ctx, req.TaskID)
|
return taskService.Task.GetResult(ctx, req.TaskID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetTaskBatch 批量查询任务(成功任务标记为已下载)
|
// GetTaskBatch 批量查询任务(成功任务标记为已下载)
|
||||||
func (c *task) GetTaskBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
|
func (c *task) GetTaskBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
|
||||||
return service.Task.GetBatch(ctx, req)
|
return taskService.Task.GetBatch(ctx, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListTask 任务列表分页查询
|
// ListTask 任务列表分页查询
|
||||||
func (c *task) ListTask(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
|
func (c *task) ListTask(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
|
||||||
return service.Task.List(ctx, req)
|
return taskService.Task.List(ctx, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunWork 手动触发一次 worker(由上层定时任务调用)
|
// RunWork 手动触发一次 worker(由上层定时任务调用)
|
||||||
func (c *task) RunWork(ctx context.Context, req *dto.RunWorkReq) (res *dto.RunWorkRes, err error) {
|
func (c *task) RunWork(ctx context.Context, req *dto.RunWorkReq) (res *dto.RunWorkRes, err error) {
|
||||||
return service.AsyncWorker.RunOnce(ctx, req)
|
return taskService.AsyncWorker.RunOnce(ctx, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CleanWork 手动触发一次 cleaner(由上层定时任务调用)
|
// CleanWork 手动触发一次 cleaner(由上层定时任务调用)
|
||||||
func (c *task) CleanWork(ctx context.Context, req *dto.CleanWorkReq) (res *dto.CleanWorkRes, err error) {
|
func (c *task) CleanWork(ctx context.Context, req *dto.CleanWorkReq) (res *dto.CleanWorkRes, err error) {
|
||||||
return service.Cleaner.RunOnce(ctx)
|
return job.Cleaner.RunOnce(ctx)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"model-gateway/consts/public"
|
"model-gateway/consts/public"
|
||||||
"model-gateway/model/dto"
|
"model-gateway/model/dto"
|
||||||
"model-gateway/model/entity"
|
"model-gateway/model/entity"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
"gitea.com/red-future/common/db/gfdb"
|
"gitea.com/red-future/common/db/gfdb"
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
@@ -90,22 +91,28 @@ func (d *modelDao) GetByAcrossTenant(ctx context.Context, req *entity.AsynchMode
|
|||||||
|
|
||||||
// GetByCreatorAndPlatform 按创建者、平台获取
|
// GetByCreatorAndPlatform 按创建者、平台获取
|
||||||
func (d *modelDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.AsynchModel, total int, err error) {
|
func (d *modelDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.AsynchModel, total int, err error) {
|
||||||
// 基础 SQL
|
|
||||||
sql := `
|
sql := `
|
||||||
SELECT DISTINCT ON (model_name) *
|
SELECT DISTINCT ON (model_name) *
|
||||||
FROM asynch_models
|
FROM asynch_models
|
||||||
WHERE deleted_at IS NULL
|
WHERE deleted_at IS NULL
|
||||||
AND (? = '' OR model_name LIKE ?)
|
AND (? = '' OR model_name LIKE ?)
|
||||||
AND (? = 0 OR model_type = ?)
|
|
||||||
`
|
`
|
||||||
args := []any{
|
args := []any{
|
||||||
req.ModelName, "%" + req.ModelName + "%",
|
req.ModelName, "%" + req.ModelName + "%",
|
||||||
req.ModelType, req.ModelType,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// modelType: 传 6 模糊匹配 6%
|
||||||
|
if req.ModelType > 0 {
|
||||||
|
prefix := strconv.Itoa(req.ModelType)[:1] // 截取第一位
|
||||||
|
sql += ` AND model_type::text LIKE ? `
|
||||||
|
args = append(args, prefix+"%")
|
||||||
|
}
|
||||||
|
|
||||||
if !g.IsEmpty(req.IsPrivate) {
|
if !g.IsEmpty(req.IsPrivate) {
|
||||||
sql += ` AND is_private = ? `
|
sql += ` AND is_private = ? `
|
||||||
args = append(args, req.IsPrivate)
|
args = append(args, req.IsPrivate)
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.IsOwner != nil && *req.IsOwner == 0 {
|
if req.IsOwner != nil && *req.IsOwner == 0 {
|
||||||
if req.Enabled != nil && *req.Enabled == 1 {
|
if req.Enabled != nil && *req.Enabled == 1 {
|
||||||
sql += ` AND creator = ? AND is_owner = ? AND enabled=1 `
|
sql += ` AND creator = ? AND is_owner = ? AND enabled=1 `
|
||||||
@@ -114,9 +121,7 @@ WHERE deleted_at IS NULL
|
|||||||
} else {
|
} else {
|
||||||
sql += ` AND creator = ? AND is_owner = ? `
|
sql += ` AND creator = ? AND is_owner = ? `
|
||||||
}
|
}
|
||||||
|
args = append(args, req.Creator, req.IsOwner)
|
||||||
args = append(args, req.Creator)
|
|
||||||
args = append(args, req.IsOwner)
|
|
||||||
} else if req.IsOwner != nil && *req.IsOwner == 1 {
|
} else if req.IsOwner != nil && *req.IsOwner == 1 {
|
||||||
if req.Enabled != nil && *req.Enabled == 1 {
|
if req.Enabled != nil && *req.Enabled == 1 {
|
||||||
sql += ` AND ((creator = ? AND is_owner = ? AND enabled=1) OR (is_owner = 0 AND enabled=1)) `
|
sql += ` AND ((creator = ? AND is_owner = ? AND enabled=1) OR (is_owner = 0 AND enabled=1)) `
|
||||||
@@ -125,11 +130,9 @@ WHERE deleted_at IS NULL
|
|||||||
} else {
|
} else {
|
||||||
sql += ` AND ((creator = ? AND is_owner = ?) OR (is_owner = 0 AND enabled=1)) `
|
sql += ` AND ((creator = ? AND is_owner = ?) OR (is_owner = 0 AND enabled=1)) `
|
||||||
}
|
}
|
||||||
args = append(args, req.Creator)
|
args = append(args, req.Creator, req.IsOwner)
|
||||||
args = append(args, req.IsOwner)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 最后拼接排序
|
|
||||||
sql += ` ORDER BY model_name, is_owner DESC, created_at DESC`
|
sql += ` ORDER BY model_name, is_owner DESC, created_at DESC`
|
||||||
|
|
||||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, sql, args...)
|
r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, sql, args...)
|
||||||
|
|||||||
7
main.go
7
main.go
@@ -3,13 +3,14 @@ package main
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"model-gateway/model/dto"
|
"model-gateway/model/dto"
|
||||||
|
"model-gateway/service/job"
|
||||||
|
"model-gateway/service/task"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"model-gateway/controller"
|
"model-gateway/controller"
|
||||||
"model-gateway/service"
|
|
||||||
|
|
||||||
"gitea.com/red-future/common/http"
|
"gitea.com/red-future/common/http"
|
||||||
"gitea.com/red-future/common/jaeger"
|
"gitea.com/red-future/common/jaeger"
|
||||||
@@ -62,7 +63,7 @@ func startAutoRunner(ctx context.Context) {
|
|||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
if _, err := service.AsyncWorker.RunOnce(ctx, &dto.RunWorkReq{
|
if _, err := task.AsyncWorker.RunOnce(ctx, &dto.RunWorkReq{
|
||||||
BatchSize: batchSize,
|
BatchSize: batchSize,
|
||||||
Goroutines: goroutines,
|
Goroutines: goroutines,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
@@ -87,7 +88,7 @@ func startAutoRunner(ctx context.Context) {
|
|||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
_, _ = service.Cleaner.RunOnce(ctx)
|
_, _ = job.Cleaner.RunOnce(ctx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|||||||
@@ -10,33 +10,33 @@ import (
|
|||||||
// CreateModelReq 添加模型配置
|
// CreateModelReq 添加模型配置
|
||||||
type CreateModelReq struct {
|
type CreateModelReq struct {
|
||||||
g.Meta `path:"/createModel" method:"post" tags:"模型管理" summary:"创建模型配置" dc:"添加新的模型配置"`
|
g.Meta `path:"/createModel" method:"post" tags:"模型管理" summary:"创建模型配置" dc:"添加新的模型配置"`
|
||||||
ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"模型名称(唯一标识)"`
|
ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"模型名称(唯一标识)"`
|
||||||
ModelType int `p:"modelType" json:"modelType" v:"required#modelType不能为空" dc:"模型类型:1-文本生成 2-图像生成 3-语音 4-视频 5-多模态"`
|
ModelType int `p:"modelType" json:"modelType" v:"required#modelType不能为空" dc:"模型类型:1-文本生成 2-图像生成 3-语音 4-视频 5-多模态"`
|
||||||
BaseURL string `p:"baseUrl" json:"baseUrl" v:"required#baseUrl不能为空" dc:"模型服务基础地址(如 gateway(s)://host:port)"`
|
BaseURL string `p:"baseUrl" json:"baseUrl" v:"required#baseUrl不能为空" dc:"模型服务基础地址(如 gateway(s)://host:port)"`
|
||||||
HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式:GET/POST(默认POST)"`
|
HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式:GET/POST(默认POST)"`
|
||||||
HeadMsg string `p:"headMsg" json:"headMsg" dc:"请求头绑定(支持多个,逗号分隔),示例:Authorization:Bearer xxx,Content-Type:application/json"`
|
HeadMsg string `p:"headMsg" json:"headMsg" dc:"请求头绑定(支持多个,逗号分隔),示例:Authorization:Bearer xxx,Content-Type:application/json"`
|
||||||
IsPrivate *int `p:"isPrivate" json:"isPrivate" v:"in:0,1#私有化参数只能为0或1" dc:"是否私有化:0-私有(默认) 1-公共"`
|
IsPrivate *int `p:"isPrivate" json:"isPrivate" v:"in:0,1#私有化参数只能为0或1" dc:"是否私有化:0-私有(默认) 1-公共"`
|
||||||
Enabled *int `p:"enabled" json:"enabled" v:"in:0,1#启用参数只能为0或1" dc:"是否启用:0-禁用,1-启用(默认1)"`
|
Enabled *int `p:"enabled" json:"enabled" v:"in:0,1#启用参数只能为0或1" dc:"是否启用:0-禁用,1-启用(默认1)"`
|
||||||
IsChatModel *int `p:"isChatModel" json:"isChatModel" v:"in:0,1#对话模型参数只能为0或1" dc:"是否为对话模型:0-否,1-是(默认0)"`
|
IsChatModel *int `p:"isChatModel" json:"isChatModel" v:"in:0,1#对话模型参数只能为0或1" dc:"是否为对话模型:0-否,1-是(默认0)"`
|
||||||
IsOwner *int `p:"isOwner" json:"isOwner" v:"in:0,1#是否为所有者参数只能为0或1" dc:"是否为所有者:0-否,1-是(默认0)"`
|
IsOwner *int `p:"isOwner" json:"isOwner" v:"in:0,1#是否为所有者参数只能为0或1" dc:"是否为所有者:0-否,1-是(默认0)"`
|
||||||
OperatorName string `p:"operatorName" json:"operatorName" v:"required#operatorName不能为空" dc:"运营商名称"`
|
OperatorName string `p:"operatorName" json:"operatorName" v:"required#operatorName不能为空" dc:"运营商名称"`
|
||||||
TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"`
|
TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"`
|
||||||
ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"`
|
ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"`
|
||||||
QueryConfig map[string]any `p:"queryConfig" json:"queryConfig" dc:"结果配置"`
|
QueryConfig map[string]any `p:"queryConfig" json:"queryConfig" dc:"结果配置"`
|
||||||
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥,用于模型认证"`
|
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥,用于模型认证"`
|
||||||
Form map[string]any `p:"form" json:"form" dc:"动态表单配置(JSON),用于前端渲染配置项"`
|
Form []map[string]any `p:"form" json:"form" dc:"动态表单配置(JSON),用于前端渲染配置项"`
|
||||||
RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求映射"`
|
RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求映射"`
|
||||||
ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回映射"`
|
ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回映射"`
|
||||||
ResponseBody map[string]any `p:"responseBody" json:"responseBody" dc:"返回主体"`
|
ResponseBody map[string]any `p:"responseBody" json:"responseBody" dc:"返回主体"`
|
||||||
ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"`
|
ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"`
|
||||||
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数(默认10)"`
|
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数(默认10)"`
|
||||||
QueueLimit int `p:"queueLimit" json:"queueLimit" dc:"排队队列上限(默认1000)"`
|
QueueLimit int `p:"queueLimit" json:"queueLimit" dc:"排队队列上限(默认1000)"`
|
||||||
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒,默认600)"`
|
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒,默认600)"`
|
||||||
ExpectedSeconds int `p:"expectedSeconds" json:"expectedSeconds" dc:"模型预计执行时间(秒,默认600)"`
|
ExpectedSeconds int `p:"expectedSeconds" json:"expectedSeconds" dc:"模型预计执行时间(秒,默认600)"`
|
||||||
RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数(默认3)"`
|
RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数(默认3)"`
|
||||||
RetryQueueMaxSeconds int `p:"retryQueueMaxSeconds" json:"retryQueueMaxSeconds" dc:"失败重试最大排队时间(秒,默认600)"`
|
RetryQueueMaxSeconds int `p:"retryQueueMaxSeconds" json:"retryQueueMaxSeconds" dc:"失败重试最大排队时间(秒,默认600)"`
|
||||||
AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"任务完成后自动清理时间(秒,默认86400)"`
|
AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"任务完成后自动清理时间(秒,默认86400)"`
|
||||||
Remark string `p:"remark" json:"remark" dc:"备注说明"`
|
Remark string `p:"remark" json:"remark" dc:"备注说明"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type CreateModelRes struct {
|
type CreateModelRes struct {
|
||||||
@@ -45,34 +45,34 @@ type CreateModelRes struct {
|
|||||||
|
|
||||||
type UpdateModelReq struct {
|
type UpdateModelReq struct {
|
||||||
g.Meta `path:"/updateModel" method:"put" tags:"模型管理" summary:"更新模型配置" dc:"更新指定ID的模型配置"`
|
g.Meta `path:"/updateModel" method:"put" tags:"模型管理" summary:"更新模型配置" dc:"更新指定ID的模型配置"`
|
||||||
ID int64 `p:"id" json:"id" v:"required#id不能为空" dc:"配置ID"`
|
ID int64 `p:"id" json:"id" v:"required#id不能为空" dc:"配置ID"`
|
||||||
ModelName string `p:"modelName" json:"modelName" dc:"模型名称(唯一标识)"`
|
ModelName string `p:"modelName" json:"modelName" dc:"模型名称(唯一标识)"`
|
||||||
ModelType int `p:"modelType" json:"modelType" dc:"模型类型ID列表(逗号分隔)(可选更新)"`
|
ModelType int `p:"modelType" json:"modelType" dc:"模型类型ID列表(逗号分隔)(可选更新)"`
|
||||||
BaseURL string `p:"baseUrl" json:"baseUrl" dc:"模型服务基础地址"`
|
BaseURL string `p:"baseUrl" json:"baseUrl" dc:"模型服务基础地址"`
|
||||||
HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式:GET/POST(可选更新)"`
|
HttpMethod string `p:"httpMethod" json:"httpMethod" dc:"请求方式:GET/POST(可选更新)"`
|
||||||
HeadMsg string `p:"headMsg" json:"headMsg" dc:"请求头绑定(可选更新)"`
|
HeadMsg string `p:"headMsg" json:"headMsg" dc:"请求头绑定(可选更新)"`
|
||||||
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥,用于模型认证(可选更新)"`
|
ApiKey string `p:"apiKey" json:"apiKey" dc:"调用凭证/密钥,用于模型认证(可选更新)"`
|
||||||
Form map[string]any `p:"form" json:"form" dc:"动态表单配置(JSON)(可选更新)"`
|
Form []map[string]any `p:"form" json:"form" dc:"动态表单配置(JSON)(可选更新)"`
|
||||||
RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求参数映射(可选更新)"`
|
RequestMapping map[string]any `p:"requestMapping" json:"requestMapping" dc:"请求参数映射(可选更新)"`
|
||||||
ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回参数映射(可选更新)"`
|
ResponseMapping map[string]any `p:"responseMapping" json:"responseMapping" dc:"返回参数映射(可选更新)"`
|
||||||
ResponseBody map[string]any `p:"responseBody" json:"responseBody" dc:"返回主体(可选更新)"`
|
ResponseBody map[string]any `p:"responseBody" json:"responseBody" dc:"返回主体(可选更新)"`
|
||||||
ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"`
|
ResponseTokenField string `p:"responseTokenField" json:"responseTokenField" dc:"响应中消耗token的字段映射"`
|
||||||
Enabled *int `p:"enabled" json:"enabled" dc:"是否启用:0-禁用,1-启用(可选更新)"`
|
Enabled *int `p:"enabled" json:"enabled" dc:"是否启用:0-禁用,1-启用(可选更新)"`
|
||||||
IsPrivate *int `p:"isPrivate" json:"isPrivate" v:"in:0,1#私有化参数只能为0或1" dc:"是否私有化:0-私有(默认) 1-公共"`
|
IsPrivate *int `p:"isPrivate" json:"isPrivate" v:"in:0,1#私有化参数只能为0或1" dc:"是否私有化:0-私有(默认) 1-公共"`
|
||||||
IsChatModel *int `p:"isChatModel" json:"isChatModel" v:"in:0,1#对话模型参数只能为0或1" dc:"是否为对话模型:0-否,1-是(默认0)"`
|
IsChatModel *int `p:"isChatModel" json:"isChatModel" v:"in:0,1#对话模型参数只能为0或1" dc:"是否为对话模型:0-否,1-是(默认0)"`
|
||||||
IsOwner *int `p:"isOwner" json:"isOwner" v:"in:0,1#是否为所有者参数只能为0或1" dc:"是否为所有者:0-否,1-是(默认0)"`
|
IsOwner *int `p:"isOwner" json:"isOwner" v:"in:0,1#是否为所有者参数只能为0或1" dc:"是否为所有者:0-否,1-是(默认0)"`
|
||||||
OperatorName string `p:"operatorName" json:"operatorName" v:"required#operatorName不能为空" dc:"运营商名称"`
|
OperatorName string `p:"operatorName" json:"operatorName" v:"required#operatorName不能为空" dc:"运营商名称"`
|
||||||
TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"`
|
TokenConfig map[string]any `p:"tokenConfig" json:"tokenConfig" dc:"token计算配置"`
|
||||||
ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"`
|
ExtendMapping map[string]any `p:"extendMapping" json:"extendMapping" dc:"附加映射"`
|
||||||
QueryConfig map[string]any `p:"queryConfig" json:"queryConfig" dc:"结果配置"`
|
QueryConfig map[string]any `p:"queryConfig" json:"queryConfig" dc:"结果配置"`
|
||||||
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数(可选更新)"`
|
MaxConcurrency int `p:"maxConcurrency" json:"maxConcurrency" dc:"最大并发数(可选更新)"`
|
||||||
QueueLimit int `p:"queueLimit" json:"queueLimit" dc:"排队队列上限(可选更新)"`
|
QueueLimit int `p:"queueLimit" json:"queueLimit" dc:"排队队列上限(可选更新)"`
|
||||||
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒)(可选更新)"`
|
TimeoutSeconds int `p:"timeoutSeconds" json:"timeoutSeconds" dc:"请求超时时间(秒)(可选更新)"`
|
||||||
ExpectedSeconds int `p:"expectedSeconds" json:"expectedSeconds" dc:"模型预计执行时间(秒)(可选更新)"`
|
ExpectedSeconds int `p:"expectedSeconds" json:"expectedSeconds" dc:"模型预计执行时间(秒)(可选更新)"`
|
||||||
RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数(可选更新)"`
|
RetryTimes int `p:"retryTimes" json:"retryTimes" dc:"失败重试次数(可选更新)"`
|
||||||
RetryQueueMaxSeconds int `p:"retryQueueMaxSeconds" json:"retryQueueMaxSeconds" dc:"失败重试最大排队时间(秒)(可选更新)"`
|
RetryQueueMaxSeconds int `p:"retryQueueMaxSeconds" json:"retryQueueMaxSeconds" dc:"失败重试最大排队时间(秒)(可选更新)"`
|
||||||
AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"自动清理间隔(秒)(可选更新)"`
|
AutoCleanSeconds int `p:"autoCleanSeconds" json:"autoCleanSeconds" dc:"自动清理间隔(秒)(可选更新)"`
|
||||||
Remark string `p:"remark" json:"remark" dc:"备注说明(可选更新)"`
|
Remark string `p:"remark" json:"remark" dc:"备注说明(可选更新)"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type UpdateModelRes struct {
|
type UpdateModelRes struct {
|
||||||
@@ -166,3 +166,20 @@ type GetIsChatModelReq struct {
|
|||||||
type GetIsChatModelRes struct {
|
type GetIsChatModelRes struct {
|
||||||
Model any `json:"model" dc:"模型详情"`
|
Model any `json:"model" dc:"模型详情"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NodeFormField 节点表单
|
||||||
|
type NodeFormField struct {
|
||||||
|
Value any `json:"value" dc:"字段值"`
|
||||||
|
Field string `json:"field" dc:"字段标识"`
|
||||||
|
Label string `json:"label" dc:"字段标签"`
|
||||||
|
Type string `json:"type" dc:"字段类型"`
|
||||||
|
Required bool `json:"required" dc:"是否必填"`
|
||||||
|
Default any `json:"default,omitempty" dc:"默认值"`
|
||||||
|
Options []SelectOption `json:"options" dc:"下拉选项列表"`
|
||||||
|
FieldConstraint any `json:"fieldConstraint" dc:"字段约束"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SelectOption struct {
|
||||||
|
Label string `json:"label" dc:"选项标签"`
|
||||||
|
Value string `json:"value" dc:"选项值"`
|
||||||
|
}
|
||||||
|
|||||||
@@ -69,32 +69,32 @@ var AsynchModelCol = asynchModelCol{
|
|||||||
// AsynchModel 异步模型配置
|
// AsynchModel 异步模型配置
|
||||||
type AsynchModel struct {
|
type AsynchModel struct {
|
||||||
beans.SQLBaseDO `orm:",inline"`
|
beans.SQLBaseDO `orm:",inline"`
|
||||||
ModelName string `orm:"model_name" json:"modelName"`
|
ModelName string `orm:"model_name" json:"modelName"`
|
||||||
ModelType int `orm:"model_type" json:"modelType"`
|
ModelType int `orm:"model_type" json:"modelType"`
|
||||||
BaseURL string `orm:"base_url" json:"baseUrl"`
|
BaseURL string `orm:"base_url" json:"baseUrl"`
|
||||||
HttpMethod string `orm:"http_method" json:"httpMethod"`
|
HttpMethod string `orm:"http_method" json:"httpMethod"`
|
||||||
HeadMsg string `orm:"head_msg" json:"headMsg"`
|
HeadMsg string `orm:"head_msg" json:"headMsg"`
|
||||||
Form map[string]any `orm:"form_json" json:"form"`
|
Form []map[string]any `orm:"form_json" json:"form"`
|
||||||
RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"`
|
RequestMapping map[string]any `orm:"request_mapping" json:"requestMapping"`
|
||||||
ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"`
|
ResponseMapping map[string]any `orm:"response_mapping" json:"responseMapping"`
|
||||||
ResponseBody map[string]any `orm:"response_body" json:"responseBody"`
|
ResponseBody map[string]any `orm:"response_body" json:"responseBody"`
|
||||||
ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"`
|
ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"`
|
||||||
Prompt string `orm:"prompt" json:"prompt"`
|
Prompt string `orm:"prompt" json:"prompt"`
|
||||||
IsPrivate *int `orm:"is_private" json:"isPrivate"`
|
IsPrivate *int `orm:"is_private" json:"isPrivate"`
|
||||||
IsChatModel *int `orm:"is_chat_model" json:"isChatModel"`
|
IsChatModel *int `orm:"is_chat_model" json:"isChatModel"`
|
||||||
ApiKey string `orm:"api_key" json:"apiKey"`
|
ApiKey string `orm:"api_key" json:"apiKey"`
|
||||||
Enabled *int `orm:"enabled" json:"enabled"`
|
Enabled *int `orm:"enabled" json:"enabled"`
|
||||||
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
|
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
|
||||||
QueueLimit int `orm:"queue_limit" json:"queueLimit"`
|
QueueLimit int `orm:"queue_limit" json:"queueLimit"`
|
||||||
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
|
TimeoutSeconds int `orm:"timeout_seconds" json:"timeoutSeconds"`
|
||||||
ExpectedSeconds int `orm:"expected_seconds" json:"expectedSeconds"`
|
ExpectedSeconds int `orm:"expected_seconds" json:"expectedSeconds"`
|
||||||
RetryTimes int `orm:"retry_times" json:"retryTimes"`
|
RetryTimes int `orm:"retry_times" json:"retryTimes"`
|
||||||
RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"retryQueueMaxSeconds"`
|
RetryQueueMaxSeconds int `orm:"retry_queue_max_seconds" json:"retryQueueMaxSeconds"`
|
||||||
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
|
AutoCleanSeconds int `orm:"auto_clean_seconds" json:"autoCleanSeconds"`
|
||||||
Remark string `orm:"remark" json:"remark"`
|
Remark string `orm:"remark" json:"remark"`
|
||||||
IsOwner *int `json:"isOwner" orm:"is_owner"`
|
IsOwner *int `json:"isOwner" orm:"is_owner"`
|
||||||
OperatorName string `orm:"operator_name" json:"operatorName"`
|
OperatorName string `orm:"operator_name" json:"operatorName"`
|
||||||
TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"`
|
TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"`
|
||||||
ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"`
|
ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"`
|
||||||
QueryConfig map[string]any `orm:"query_config" json:"queryConfig"`
|
QueryConfig map[string]any `orm:"query_config" json:"queryConfig"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
package service
|
|
||||||
@@ -1,8 +1,9 @@
|
|||||||
package service
|
package job
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"model-gateway/model/dto"
|
"model-gateway/model/dto"
|
||||||
|
"model-gateway/service/queue"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -20,32 +21,32 @@ func (c *cleaner) RunOnce(ctx context.Context) (res *dto.CleanWorkRes, err error
|
|||||||
// 1) 清理已下载(state=4)且过期的任务(硬删除 + OSS)
|
// 1) 清理已下载(state=4)且过期的任务(硬删除 + OSS)
|
||||||
expired, err := dao.Task.ListExpiredDownloadedGlobal(ctx, 200)
|
expired, err := dao.Task.ListExpiredDownloadedGlobal(ctx, 200)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
g.Log().Errorf(ctx, "[cleaner] list expired(downloaded) error: %v", err)
|
g.Log().Errorf(ctx, "[清理] 查询已下载过期任务失败: %v", err)
|
||||||
} else {
|
} else {
|
||||||
for _, t := range expired {
|
for _, t := range expired {
|
||||||
_ = os.Remove(t.TmpFile)
|
_ = os.Remove(t.TmpFile)
|
||||||
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
|
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
|
||||||
}
|
}
|
||||||
g.Log().Infof(ctx, "[cleaner] expired(downloaded) cleaned, count=%d", len(expired))
|
g.Log().Infof(ctx, "[清理] 已下载过期任务清理完成, count=%d", len(expired))
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2) 超时任务标失败
|
// 2) 超时任务标失败
|
||||||
list, err := dao.Task.ListTimeoutTasksGlobal(ctx, 200)
|
list, err := dao.Task.ListTimeoutTasksGlobal(ctx, 200)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
g.Log().Errorf(ctx, "[cleaner] list timeout error: %v", err)
|
g.Log().Errorf(ctx, "[清理] 查询超时任务失败: %v", err)
|
||||||
} else {
|
} else {
|
||||||
for _, t := range list {
|
for _, t := range list {
|
||||||
t.ErrorMsg = "任务超时自动失败"
|
t.ErrorMsg = "任务超时自动失败"
|
||||||
_ = dao.Task.UpdateFailedGlobal(ctx, t)
|
_ = dao.Task.UpdateFailedGlobal(ctx, t)
|
||||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||||
}
|
}
|
||||||
g.Log().Infof(ctx, "[cleaner] timeout cleaned, count=%d", len(list))
|
g.Log().Infof(ctx, "[清理] 超时任务处理完成, count=%d", len(list))
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3) 失败(state=3)的任务按模型配置 retry_times 重新入队(放到队尾)
|
// 3) 失败(state=3)的任务按模型配置 retry_times 重新入队(放到队尾)
|
||||||
retryable, err := dao.Task.ListFailedRetryableGlobal(ctx, 200)
|
retryable, err := dao.Task.ListFailedRetryableGlobal(ctx, 200)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
g.Log().Errorf(ctx, "[cleaner] list failed retryable error: %v", err)
|
g.Log().Errorf(ctx, "[清理] 查询可重试任务失败: %v", err)
|
||||||
} else {
|
} else {
|
||||||
for _, t := range retryable {
|
for _, t := range retryable {
|
||||||
// 失败任务重新入队(state=3 -> 0)前,先严格占用 queue_limit slot;占用失败则留在失败态,下一轮再尝试
|
// 失败任务重新入队(state=3 -> 0)前,先严格占用 queue_limit slot;占用失败则留在失败态,下一轮再尝试
|
||||||
@@ -54,9 +55,9 @@ func (c *cleaner) RunOnce(ctx context.Context) (res *dto.CleanWorkRes, err error
|
|||||||
if err != nil || m == nil {
|
if err != nil || m == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
limit := GetRuntimeQueueLimit(ctx, t.ModelName, m.QueueLimit)
|
limit := queue.GetRuntimeQueueLimit(ctx, t.ModelName, m.QueueLimit)
|
||||||
if limit > 0 {
|
if limit > 0 {
|
||||||
ok, _ := AcquireQueueSlot(ctx, t.ModelName, t.TaskID, limit, m.ExpectedSeconds)
|
ok, _ := queue.AcquireQueueSlot(ctx, t.ModelName, t.TaskID, limit, m.ExpectedSeconds)
|
||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -76,21 +77,21 @@ func (c *cleaner) RunOnce(ctx context.Context) (res *dto.CleanWorkRes, err error
|
|||||||
}
|
}
|
||||||
_ = dao.Task.RequeueForRetryGlobal(ctx, t.Id, enqueueAt)
|
_ = dao.Task.RequeueForRetryGlobal(ctx, t.Id, enqueueAt)
|
||||||
}
|
}
|
||||||
g.Log().Infof(ctx, "[cleaner] failed retryable cleaned, count=%d", len(retryable))
|
g.Log().Infof(ctx, "[清理] 可重试任务重新入队完成, count=%d", len(retryable))
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4) 超过重试次数仍失败(state=3)的任务:硬删除
|
// 4) 超过重试次数仍失败(state=3)的任务:硬删除
|
||||||
exhausted, err := dao.Task.ListFailedExhaustedGlobal(ctx, 200)
|
exhausted, err := dao.Task.ListFailedExhaustedGlobal(ctx, 200)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
g.Log().Errorf(ctx, "[cleaner] list failed exhausted error: %v", err)
|
g.Log().Errorf(ctx, "[清理] 查询重试耗尽任务失败: %v", err)
|
||||||
} else {
|
} else {
|
||||||
for _, t := range exhausted {
|
for _, t := range exhausted {
|
||||||
_ = os.Remove(t.TmpFile)
|
_ = os.Remove(t.TmpFile)
|
||||||
// 重试耗尽硬删除:释放闸门占位(兜底,若此前已释放则幂等)
|
// 重试耗尽硬删除:释放闸门占位(兜底,若此前已释放则幂等)
|
||||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||||
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
|
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
|
||||||
}
|
}
|
||||||
g.Log().Infof(ctx, "[cleaner] failed exhausted cleaned, count=%d", len(exhausted))
|
g.Log().Infof(ctx, "[清理] 重试耗尽任务清理完成, count=%d", len(exhausted))
|
||||||
}
|
}
|
||||||
return &dto.CleanWorkRes{
|
return &dto.CleanWorkRes{
|
||||||
Ok: true,
|
Ok: true,
|
||||||
254
service/model/model_service.go
Normal file
254
service/model/model_service.go
Normal file
@@ -0,0 +1,254 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"model-gateway/common/util"
|
||||||
|
"model-gateway/consts/public"
|
||||||
|
"model-gateway/dao"
|
||||||
|
"model-gateway/model/dto"
|
||||||
|
"model-gateway/model/entity"
|
||||||
|
"model-gateway/service/gateway"
|
||||||
|
|
||||||
|
"gitea.com/red-future/common/beans"
|
||||||
|
"gitea.com/red-future/common/db/gfdb"
|
||||||
|
"gitea.com/red-future/common/utils"
|
||||||
|
"github.com/gogf/gf/v2/database/gdb"
|
||||||
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
var Model = &modelService{}
|
||||||
|
|
||||||
|
type modelService struct{}
|
||||||
|
|
||||||
|
// Create 创建模型
|
||||||
|
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (*dto.CreateModelRes, error) {
|
||||||
|
// 1)如果设为会话模型,先把该用户旧会话模型取消
|
||||||
|
if !g.IsEmpty(req.IsChatModel) && *req.IsChatModel == 1 {
|
||||||
|
if err := s.clearUserChatModel(ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 2)判断是否超管,决定 isOwner
|
||||||
|
req.IsOwner = gconv.PtrInt(1)
|
||||||
|
if isAdmin, _ := gateway.IsSuperAdmin(ctx); isAdmin {
|
||||||
|
req.IsOwner = gconv.PtrInt(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3)入库
|
||||||
|
id, err := dao.Model.Insert(ctx, util.ConvertTo[entity.AsynchModel](req))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &dto.CreateModelRes{ID: id}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update 更新模型配置
|
||||||
|
func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) error {
|
||||||
|
// 1)会话模型唯一性校验
|
||||||
|
if req.IsChatModel != nil && *req.IsChatModel == 1 {
|
||||||
|
if err := s.checkChatModelUnique(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 2)超管创建/普通用户更新
|
||||||
|
req.IsOwner = gconv.PtrInt(1)
|
||||||
|
if isAdmin, _ := gateway.IsSuperAdmin(ctx); isAdmin {
|
||||||
|
req.IsOwner = gconv.PtrInt(0)
|
||||||
|
_, err := dao.Model.Update(ctx, util.ConvertTo[entity.AsynchModel](req))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// 3)跨租户判断:超管的模型不允许直接修改,走插入新记录
|
||||||
|
model, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
|
||||||
|
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if model.TenantId == 1 {
|
||||||
|
_, err = dao.Model.Insert(ctx, util.ConvertTo[entity.AsynchModel](req))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = dao.Model.Update(ctx, util.ConvertTo[entity.AsynchModel](req))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete 删除模型
|
||||||
|
func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) error {
|
||||||
|
_, err := dao.Model.Delete(ctx, &entity.AsynchModel{
|
||||||
|
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get 获取模型详情
|
||||||
|
func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetModelRes, error) {
|
||||||
|
user, err := utils.GetUserInfo(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if g.IsEmpty(req.ID) {
|
||||||
|
req.Creator = user.UserName
|
||||||
|
}
|
||||||
|
modelReq := new(entity.AsynchModel)
|
||||||
|
err = gconv.Struct(req, modelReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
model, err := dao.Model.Get(ctx, modelReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &dto.GetModelRes{
|
||||||
|
Model: model,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// List 获取模型列表
|
||||||
|
func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (*dto.ListModelRes, error) {
|
||||||
|
// 1)判断超管
|
||||||
|
req.IsOwner = gconv.PtrInt(1)
|
||||||
|
if isAdmin, _ := gateway.IsSuperAdmin(ctx); isAdmin {
|
||||||
|
req.IsOwner = gconv.PtrInt(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2)获取当前用户
|
||||||
|
user, err := utils.GetUserInfo(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Creator = user.UserName
|
||||||
|
|
||||||
|
// 3)查询
|
||||||
|
models, total, err := dao.Model.GetByCreatorAndPlatform(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &dto.ListModelRes{List: models, Total: total}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateChatModel 设置会话模型
|
||||||
|
func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error {
|
||||||
|
// 1)校验新模型存在
|
||||||
|
newModel, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
|
||||||
|
SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
|
||||||
|
})
|
||||||
|
if err != nil || newModel == nil {
|
||||||
|
return errors.New("新会话模型不存在")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2)获取当前用户的会话模型
|
||||||
|
user, err := utils.GetUserInfo(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
currentModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||||
|
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
|
||||||
|
IsChatModel: new(1),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3)事务:取消旧的 + 设置新的
|
||||||
|
return gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
||||||
|
if !g.IsEmpty(currentModel) {
|
||||||
|
if currentModel.ModelType != public.ModelTypeInference {
|
||||||
|
return errors.New("当前模型为非推理模型,不能设置为会话模型")
|
||||||
|
}
|
||||||
|
if currentModel.Id != req.Id {
|
||||||
|
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
||||||
|
SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id},
|
||||||
|
IsChatModel: gconv.PtrInt(0),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
||||||
|
SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
|
||||||
|
IsChatModel: gconv.PtrInt(1),
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetIsChatModel 获取当前用户会话模型
|
||||||
|
func (s *modelService) GetIsChatModel(ctx context.Context) (*dto.GetIsChatModelRes, error) {
|
||||||
|
user, err := utils.GetUserInfo(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||||
|
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
|
||||||
|
IsChatModel: new(1),
|
||||||
|
})
|
||||||
|
if err != nil || model == nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &dto.GetIsChatModelRes{Model: model}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== 辅助方法 ====================
|
||||||
|
|
||||||
|
// clearUserChatModel 清除当前用户旧会话模型
|
||||||
|
func (s *modelService) clearUserChatModel(ctx context.Context) error {
|
||||||
|
user, err := utils.GetUserInfo(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||||
|
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
|
||||||
|
IsChatModel: new(1),
|
||||||
|
})
|
||||||
|
if err != nil || model == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
||||||
|
SQLBaseDO: beans.SQLBaseDO{Id: model.Id},
|
||||||
|
IsChatModel: gconv.PtrInt(0),
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkChatModelUnique 校验用户是否已有会话模型
|
||||||
|
func (s *modelService) checkChatModelUnique(ctx context.Context) error {
|
||||||
|
user, err := utils.GetUserInfo(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||||
|
SQLBaseDO: beans.SQLBaseDO{Creator: user.UserName},
|
||||||
|
IsChatModel: new(1),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if model != nil {
|
||||||
|
return errors.New("用户已存在会话模型")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetModelTypesFromConfig 从配置文件读取模型类型
|
||||||
|
func GetModelTypesFromConfig() (res *dto.TypeItem, err error) {
|
||||||
|
// 返回副本,避免外部修改
|
||||||
|
types := make(map[int]string, len(public.ModelTypeName))
|
||||||
|
for k, v := range public.ModelTypeName {
|
||||||
|
types[k] = v
|
||||||
|
}
|
||||||
|
return &dto.TypeItem{
|
||||||
|
Type: types,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOperatorList 获取运营商列表
|
||||||
|
func GetOperatorList() (res *dto.ListOperatorRes, err error) {
|
||||||
|
return &dto.ListOperatorRes{
|
||||||
|
List: public.OperatorList,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
@@ -1,469 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"model-gateway/model/entity"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gogf/gf/v2/container/gvar"
|
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
|
||||||
"github.com/tidwall/gjson"
|
|
||||||
"github.com/tidwall/sjson"
|
|
||||||
)
|
|
||||||
|
|
||||||
// parseHeadMsgHeaders 支持多个 header 绑定,逗号分隔:
|
|
||||||
// 示例:
|
|
||||||
// - X-API-Key:qwen3-tts-key,operation:true,count:123
|
|
||||||
// - X-API-Key:"qwen3-tts-key",operation:"true"
|
|
||||||
//
|
|
||||||
// 说明:
|
|
||||||
// - HTTP Header 最终都是字符串,这里做的是“值的字符串化表达”。
|
|
||||||
// - 若 value 用双引号包裹,会去掉外层引号再注入,便于在配置中区分字符串/布尔/数字等表达(以及避免值中包含特殊字符时歧义)。
|
|
||||||
func parseHeadMsgHeaders(headMsg string) map[string]string {
|
|
||||||
headMsg = strings.TrimSpace(headMsg)
|
|
||||||
if headMsg == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
out := map[string]string{}
|
|
||||||
parts := strings.Split(headMsg, ",")
|
|
||||||
for _, p := range parts {
|
|
||||||
p = strings.TrimSpace(p)
|
|
||||||
if p == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// HeaderName:HeaderValue(推荐) / HeaderName=HeaderValue(兼容)
|
|
||||||
if strings.Contains(p, ":") {
|
|
||||||
kv := strings.SplitN(p, ":", 2)
|
|
||||||
k := strings.TrimSpace(kv[0])
|
|
||||||
v := strings.TrimSpace(kv[1])
|
|
||||||
v = strings.Trim(v, "\"")
|
|
||||||
if k != "" && v != "" {
|
|
||||||
out[k] = v
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if strings.Contains(p, "=") {
|
|
||||||
kv := strings.SplitN(p, "=", 2)
|
|
||||||
k := strings.TrimSpace(kv[0])
|
|
||||||
v := strings.TrimSpace(kv[1])
|
|
||||||
v = strings.Trim(v, "\"")
|
|
||||||
if k != "" && v != "" {
|
|
||||||
out[k] = v
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(out) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
func payloadToQuery(payload any) (url.Values, error) {
|
|
||||||
if payload == nil {
|
|
||||||
return url.Values{}, nil
|
|
||||||
}
|
|
||||||
// 统一转成 map[string]any
|
|
||||||
b, err := json.Marshal(payload)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
m := map[string]any{}
|
|
||||||
if err := json.Unmarshal(b, &m); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
q := url.Values{}
|
|
||||||
for k, v := range m {
|
|
||||||
if v == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// 复杂类型直接 json 字符串化
|
|
||||||
switch vv := v.(type) {
|
|
||||||
case string:
|
|
||||||
q.Set(k, vv)
|
|
||||||
case float64, bool, int, int64, uint64:
|
|
||||||
q.Set(k, fmt.Sprintf("%v", vv))
|
|
||||||
default:
|
|
||||||
bs, _ := json.Marshal(v)
|
|
||||||
q.Set(k, string(bs))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return q, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// InvokeModel 调用模型服务,返回二进制结果
|
|
||||||
// modelKey 用于覆盖/补充模型配置 head_msg(例如每次请求携带不同的 X-API-Key)。
|
|
||||||
func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelKey string) ([]byte, error) {
|
|
||||||
if m == nil || m.BaseURL == "" {
|
|
||||||
return nil, fmt.Errorf("模型配置不完整")
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============ 新增:请求参数映射 ============
|
|
||||||
mappedPayload, err := mapRequestPayload(m.RequestMapping, payload)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("请求参数映射失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
url := strings.TrimRight(m.BaseURL, "/")
|
|
||||||
timeout := time.Duration(m.TimeoutSeconds) * time.Second
|
|
||||||
if timeout <= 0 {
|
|
||||||
timeout = 60 * time.Second
|
|
||||||
}
|
|
||||||
client := &http.Client{Timeout: timeout}
|
|
||||||
|
|
||||||
method := strings.ToUpper(strings.TrimSpace(m.HttpMethod))
|
|
||||||
if method == "" {
|
|
||||||
method = http.MethodPost
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
req *http.Request
|
|
||||||
)
|
|
||||||
switch method {
|
|
||||||
case http.MethodGet:
|
|
||||||
q, err := payloadToQuery(mappedPayload) // 使用映射后的payload
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if len(q) > 0 {
|
|
||||||
if strings.Contains(url, "?") {
|
|
||||||
url = url + "&" + q.Encode()
|
|
||||||
} else {
|
|
||||||
url = url + "?" + q.Encode()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
req, err = http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
|
||||||
default:
|
|
||||||
bodyBytes, err := json.Marshal(mappedPayload) // 使用映射后的payload
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
req, err = http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyBytes))
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 先注入模型配置 head_msg(静态头部,适合公共模型固定 API Key)
|
|
||||||
for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) {
|
|
||||||
req.Header.Set(hk, hv)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 最后注入动态 modelKey(允许覆盖/补充静态 head_msg),适合按请求动态传密钥。
|
|
||||||
for hk, hv := range parseHeadMsgHeaders(modelKey) {
|
|
||||||
req.Header.Set(hk, hv)
|
|
||||||
}
|
|
||||||
|
|
||||||
if method != http.MethodGet {
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
b, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
||||||
msg := string(b)
|
|
||||||
if len(msg) > 2000 {
|
|
||||||
msg = msg[:2000]
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============ 新增:响应参数映射 ============
|
|
||||||
mappedResponse, err := mapResponsePayload(m.ResponseMapping, b)
|
|
||||||
if err != nil {
|
|
||||||
// 响应映射失败不阻塞,返回原始数据
|
|
||||||
g.Log().Warningf(ctx, "响应参数映射失败: %v,返回原始数据", err)
|
|
||||||
return b, nil
|
|
||||||
}
|
|
||||||
// =========================================
|
|
||||||
|
|
||||||
return mappedResponse, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
//// InvokeModel 调用模型服务,返回二进制结果
|
|
||||||
//func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelKey string) ([]byte, error) {
|
|
||||||
// if m == nil || m.BaseURL == "" {
|
|
||||||
// return nil, fmt.Errorf("模型配置不完整")
|
|
||||||
// }
|
|
||||||
// // 请求参数映射
|
|
||||||
// mappedPayload, err := mapRequestPayload(m.RequestMapping, payload)
|
|
||||||
// if err != nil {
|
|
||||||
// return nil, fmt.Errorf("请求参数映射失败: %w", err)
|
|
||||||
// }
|
|
||||||
// // 合并请求头
|
|
||||||
// headers := util.ForwardHeaders(ctx)
|
|
||||||
// for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) {
|
|
||||||
// headers[hk] = hv
|
|
||||||
// }
|
|
||||||
// for hk, hv := range parseHeadMsgHeaders(modelKey) {
|
|
||||||
// headers[hk] = hv
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// // 设置超时
|
|
||||||
// timeout := time.Duration(m.TimeoutSeconds) * time.Second
|
|
||||||
// if timeout <= 0 {
|
|
||||||
// timeout = 600 * time.Second
|
|
||||||
// }
|
|
||||||
// ctx, cancel := context.WithTimeout(ctx, timeout)
|
|
||||||
// defer cancel()
|
|
||||||
//
|
|
||||||
// invokeUrl := strings.TrimRight(m.BaseURL, "/")
|
|
||||||
// method := strings.ToUpper(strings.TrimSpace(m.HttpMethod))
|
|
||||||
// if method == "" {
|
|
||||||
// method = http.MethodPost
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// var respBytes []byte
|
|
||||||
//
|
|
||||||
// switch method {
|
|
||||||
// case http.MethodGet:
|
|
||||||
// err = commonHttp.Get(ctx, invokeUrl, headers, &respBytes, mappedPayload)
|
|
||||||
// default:
|
|
||||||
// err = commonHttp.Post(ctx, invokeUrl, headers, &respBytes, mappedPayload)
|
|
||||||
// }
|
|
||||||
// if err != nil {
|
|
||||||
// return nil, err
|
|
||||||
// }
|
|
||||||
// // 响应参数映射
|
|
||||||
// mappedResponse, err := mapResponsePayload(m.ResponseMapping, respBytes)
|
|
||||||
// if err != nil {
|
|
||||||
// g.Log().Warningf(ctx, "响应参数映射失败: %v,返回原始数据", err)
|
|
||||||
// return respBytes, nil
|
|
||||||
// }
|
|
||||||
// return mappedResponse, nil
|
|
||||||
//}
|
|
||||||
|
|
||||||
// ============================================
|
|
||||||
// 映射相关函数
|
|
||||||
// ============================================
|
|
||||||
|
|
||||||
// mapRequestPayload 将标准请求映射为模型特定格式
|
|
||||||
func mapRequestPayload(mappingAny any, payload any) (any, error) {
|
|
||||||
// 1. 解析请求映射配置(值是any类型,支持bool、number等)
|
|
||||||
mapping, err := parseRequestMapping(mappingAny)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果没有映射配置,直接返回原始payload
|
|
||||||
if len(mapping) == 0 {
|
|
||||||
return payload, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. 将payload转为map
|
|
||||||
var payloadMap map[string]any
|
|
||||||
switch v := payload.(type) {
|
|
||||||
case map[string]any:
|
|
||||||
payloadMap = v
|
|
||||||
case []map[string]any:
|
|
||||||
// 如果传进来的是纯messages数组,包装成标准格式
|
|
||||||
payloadMap = map[string]any{
|
|
||||||
"messages": v,
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
// 通过JSON转换
|
|
||||||
jsonBytes, err := json.Marshal(payload)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("序列化payload失败: %w", err)
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(jsonBytes, &payloadMap); err != nil {
|
|
||||||
return nil, fmt.Errorf("反序列化payload失败: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. 用数据库固定参数覆盖/补充
|
|
||||||
for key, value := range mapping {
|
|
||||||
if existingValue, exists := payloadMap[key]; !exists || isEmptyValue(existingValue) {
|
|
||||||
payloadMap[key] = value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return payloadMap, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// mapResponsePayload 将模型响应映射为标准格式
|
|
||||||
func mapResponsePayload(mappingAny any, responseBytes []byte) ([]byte, error) {
|
|
||||||
mapping, err := parseResponseMapping(mappingAny)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if len(mapping) == 0 {
|
|
||||||
return responseBytes, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
responseStr := string(responseBytes)
|
|
||||||
resultStr := `{}`
|
|
||||||
|
|
||||||
for standardField, modelPath := range mapping {
|
|
||||||
value := gjson.Get(responseStr, modelPath)
|
|
||||||
if !value.Exists() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
resultStr, err = sjson.SetRaw(resultStr, standardField, value.Raw)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("提取字段 %s <- %s 失败: %w", standardField, modelPath, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return []byte(resultStr), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseRequestMapping(mappingAny any) (map[string]any, error) {
|
|
||||||
if mappingAny == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
result := make(map[string]any)
|
|
||||||
|
|
||||||
switch v := mappingAny.(type) {
|
|
||||||
case *gvar.Var:
|
|
||||||
if v == nil || v.IsNil() || v.IsEmpty() {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
// 尝试转成 map
|
|
||||||
if m := v.Map(); m != nil {
|
|
||||||
for k, val := range m {
|
|
||||||
result[k] = val
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
// 尝试转成 string
|
|
||||||
if s := v.String(); s != "" && s != "{}" && s != "null" {
|
|
||||||
if err := json.Unmarshal([]byte(s), &result); err != nil {
|
|
||||||
return nil, fmt.Errorf("解析请求映射字符串失败: %w", err)
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
return nil, nil
|
|
||||||
// =======================================================
|
|
||||||
|
|
||||||
case map[string]interface{}:
|
|
||||||
result = v
|
|
||||||
|
|
||||||
case string:
|
|
||||||
if v == "" || v == "{}" || v == "null" {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal([]byte(v), &result); err != nil {
|
|
||||||
return nil, fmt.Errorf("解析请求映射字符串失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
case []byte:
|
|
||||||
if len(v) == 0 {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(v, &result); err != nil {
|
|
||||||
return nil, fmt.Errorf("解析请求映射字节失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
jsonBytes, err := json.Marshal(mappingAny)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("序列化映射配置失败: %w", err)
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(jsonBytes, &result); err != nil {
|
|
||||||
return nil, fmt.Errorf("解析映射配置失败: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseResponseMapping 解析响应映射配置
|
|
||||||
// 返回值类型为 map[string]string,值都是JSON路径字符串
|
|
||||||
func parseResponseMapping(mappingAny any) (map[string]string, error) {
|
|
||||||
if mappingAny == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
mapping := make(map[string]string)
|
|
||||||
|
|
||||||
switch v := mappingAny.(type) {
|
|
||||||
case *gvar.Var:
|
|
||||||
if v == nil || v.IsNil() || v.IsEmpty() {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
if m := v.Map(); m != nil {
|
|
||||||
for k, val := range m {
|
|
||||||
if strVal, ok := val.(string); ok {
|
|
||||||
mapping[k] = strVal
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return mapping, nil
|
|
||||||
}
|
|
||||||
if s := v.String(); s != "" && s != "{}" && s != "null" {
|
|
||||||
if err := json.Unmarshal([]byte(s), &mapping); err != nil {
|
|
||||||
return nil, fmt.Errorf("解析响应映射字符串失败: %w", err)
|
|
||||||
}
|
|
||||||
return mapping, nil
|
|
||||||
}
|
|
||||||
return nil, nil
|
|
||||||
case string:
|
|
||||||
if v == "" || v == "{}" || v == "null" {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal([]byte(v), &mapping); err != nil {
|
|
||||||
return nil, fmt.Errorf("解析响应映射字符串失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
case map[string]interface{}:
|
|
||||||
// 数据库JSONB直接返回的map
|
|
||||||
for k, val := range v {
|
|
||||||
if strVal, ok := val.(string); ok {
|
|
||||||
mapping[k] = strVal
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
case []byte:
|
|
||||||
if len(v) == 0 {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(v, &mapping); err != nil {
|
|
||||||
return nil, fmt.Errorf("解析响应映射字节失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
jsonBytes, err := json.Marshal(mappingAny)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("序列化响应映射配置失败: %w", err)
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(jsonBytes, &mapping); err != nil {
|
|
||||||
return nil, fmt.Errorf("解析响应映射配置失败: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return mapping, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// isEmptyValue 判断值是否为空
|
|
||||||
func isEmptyValue(v any) bool {
|
|
||||||
if v == nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
switch val := v.(type) {
|
|
||||||
case string:
|
|
||||||
return val == ""
|
|
||||||
case []any:
|
|
||||||
return len(val) == 0
|
|
||||||
case map[string]any:
|
|
||||||
return len(val) == 0
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,389 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"model-gateway/consts/public"
|
|
||||||
"model-gateway/dao"
|
|
||||||
"model-gateway/model/dto"
|
|
||||||
"model-gateway/model/entity"
|
|
||||||
"model-gateway/service/gateway"
|
|
||||||
|
|
||||||
"gitea.com/red-future/common/beans"
|
|
||||||
"gitea.com/red-future/common/db/gfdb"
|
|
||||||
"gitea.com/red-future/common/utils"
|
|
||||||
"github.com/gogf/gf/v2/database/gdb"
|
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
|
||||||
"github.com/gogf/gf/v2/util/gconv"
|
|
||||||
)
|
|
||||||
|
|
||||||
var Model = &modelService{}
|
|
||||||
|
|
||||||
type modelService struct{}
|
|
||||||
|
|
||||||
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
|
|
||||||
// 获取当前会话模型
|
|
||||||
if !g.IsEmpty(req.IsChatModel) && *req.IsChatModel == 1 {
|
|
||||||
var user *beans.User
|
|
||||||
user, err = utils.GetUserInfo(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// 获取当前用户会话模型
|
|
||||||
var model *entity.AsynchModel
|
|
||||||
model, err = dao.Model.Get(ctx, &entity.AsynchModel{
|
|
||||||
SQLBaseDO: beans.SQLBaseDO{
|
|
||||||
Creator: user.UserName,
|
|
||||||
},
|
|
||||||
IsChatModel: new(1),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// 如果有会话模型,那就改变为 0
|
|
||||||
if model != nil {
|
|
||||||
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
|
||||||
SQLBaseDO: beans.SQLBaseDO{Id: model.Id},
|
|
||||||
IsChatModel: gconv.PtrInt(0),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
req.IsOwner = gconv.PtrInt(1)
|
|
||||||
admin, err := gateway.IsSuperAdmin(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if admin {
|
|
||||||
req.IsOwner = gconv.PtrInt(0)
|
|
||||||
}
|
|
||||||
id, err := dao.Model.Insert(ctx, &entity.AsynchModel{
|
|
||||||
ModelName: req.ModelName,
|
|
||||||
ModelType: req.ModelType,
|
|
||||||
BaseURL: req.BaseURL,
|
|
||||||
HttpMethod: req.HttpMethod,
|
|
||||||
HeadMsg: req.HeadMsg,
|
|
||||||
Form: req.Form,
|
|
||||||
RequestMapping: req.RequestMapping,
|
|
||||||
ResponseMapping: req.ResponseMapping,
|
|
||||||
ResponseBody: req.ResponseBody,
|
|
||||||
ResponseTokenField: req.ResponseTokenField,
|
|
||||||
IsPrivate: req.IsPrivate,
|
|
||||||
IsChatModel: req.IsChatModel,
|
|
||||||
ApiKey: req.ApiKey,
|
|
||||||
Enabled: req.Enabled,
|
|
||||||
MaxConcurrency: req.MaxConcurrency,
|
|
||||||
QueueLimit: req.QueueLimit,
|
|
||||||
TimeoutSeconds: req.TimeoutSeconds,
|
|
||||||
ExpectedSeconds: req.ExpectedSeconds,
|
|
||||||
RetryTimes: req.RetryTimes,
|
|
||||||
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
|
|
||||||
AutoCleanSeconds: req.AutoCleanSeconds,
|
|
||||||
Remark: req.Remark,
|
|
||||||
IsOwner: req.IsOwner,
|
|
||||||
OperatorName: req.OperatorName,
|
|
||||||
TokenConfig: req.TokenConfig,
|
|
||||||
ExtendMapping: req.ExtendMapping,
|
|
||||||
QueryConfig: req.QueryConfig,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &dto.CreateModelRes{ID: id}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) error {
|
|
||||||
//根据当前 isChatModel 来判断是否更新模型
|
|
||||||
if req.IsChatModel == gconv.PtrInt(1) {
|
|
||||||
user, err := utils.GetUserInfo(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// 获取当前用户会话模型
|
|
||||||
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
|
||||||
SQLBaseDO: beans.SQLBaseDO{
|
|
||||||
Creator: user.UserName,
|
|
||||||
},
|
|
||||||
IsChatModel: new(1),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if model != nil {
|
|
||||||
return errors.New("用户已存在会话模型,不能创建")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
req.IsOwner = gconv.PtrInt(1)
|
|
||||||
admin, err := gateway.IsSuperAdmin(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if admin {
|
|
||||||
req.IsOwner = gconv.PtrInt(0)
|
|
||||||
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
|
||||||
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
|
|
||||||
ModelName: req.ModelName,
|
|
||||||
ModelType: req.ModelType,
|
|
||||||
BaseURL: req.BaseURL,
|
|
||||||
HttpMethod: req.HttpMethod,
|
|
||||||
HeadMsg: req.HeadMsg,
|
|
||||||
Form: req.Form,
|
|
||||||
RequestMapping: req.RequestMapping,
|
|
||||||
ResponseMapping: req.ResponseMapping,
|
|
||||||
ResponseBody: req.ResponseBody,
|
|
||||||
ResponseTokenField: req.ResponseTokenField,
|
|
||||||
IsPrivate: req.IsPrivate,
|
|
||||||
IsChatModel: req.IsChatModel,
|
|
||||||
ApiKey: req.ApiKey,
|
|
||||||
Enabled: req.Enabled,
|
|
||||||
MaxConcurrency: req.MaxConcurrency,
|
|
||||||
QueueLimit: req.QueueLimit,
|
|
||||||
TimeoutSeconds: req.TimeoutSeconds,
|
|
||||||
ExpectedSeconds: req.ExpectedSeconds,
|
|
||||||
RetryTimes: req.RetryTimes,
|
|
||||||
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
|
|
||||||
AutoCleanSeconds: req.AutoCleanSeconds,
|
|
||||||
Remark: req.Remark,
|
|
||||||
IsOwner: req.IsOwner,
|
|
||||||
OperatorName: req.OperatorName,
|
|
||||||
TokenConfig: req.TokenConfig,
|
|
||||||
ExtendMapping: req.ExtendMapping,
|
|
||||||
QueryConfig: req.QueryConfig,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
// 判断当前传过来的模型id的模型是否是超级管理员的。如果是超管的进行创建,否则更新
|
|
||||||
model, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
|
|
||||||
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if model.TenantId == 1 {
|
|
||||||
insertDto := new(dto.CreateModelReq)
|
|
||||||
err = gconv.Struct(req, insertDto)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = dao.Model.Insert(ctx, &entity.AsynchModel{
|
|
||||||
ModelName: req.ModelName,
|
|
||||||
ModelType: req.ModelType,
|
|
||||||
BaseURL: req.BaseURL,
|
|
||||||
HttpMethod: req.HttpMethod,
|
|
||||||
HeadMsg: req.HeadMsg,
|
|
||||||
Form: req.Form,
|
|
||||||
RequestMapping: req.RequestMapping,
|
|
||||||
ResponseMapping: req.ResponseMapping,
|
|
||||||
ResponseBody: req.ResponseBody,
|
|
||||||
ResponseTokenField: req.ResponseTokenField,
|
|
||||||
IsPrivate: req.IsPrivate,
|
|
||||||
IsChatModel: req.IsChatModel,
|
|
||||||
ApiKey: req.ApiKey,
|
|
||||||
Enabled: req.Enabled,
|
|
||||||
MaxConcurrency: req.MaxConcurrency,
|
|
||||||
QueueLimit: req.QueueLimit,
|
|
||||||
TimeoutSeconds: req.TimeoutSeconds,
|
|
||||||
ExpectedSeconds: req.ExpectedSeconds,
|
|
||||||
RetryTimes: req.RetryTimes,
|
|
||||||
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
|
|
||||||
AutoCleanSeconds: req.AutoCleanSeconds,
|
|
||||||
Remark: req.Remark,
|
|
||||||
IsOwner: req.IsOwner,
|
|
||||||
OperatorName: req.OperatorName,
|
|
||||||
TokenConfig: req.TokenConfig,
|
|
||||||
ExtendMapping: req.ExtendMapping,
|
|
||||||
QueryConfig: req.QueryConfig,
|
|
||||||
})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
|
||||||
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
|
|
||||||
ModelName: req.ModelName,
|
|
||||||
ModelType: req.ModelType,
|
|
||||||
BaseURL: req.BaseURL,
|
|
||||||
HttpMethod: req.HttpMethod,
|
|
||||||
HeadMsg: req.HeadMsg,
|
|
||||||
Form: req.Form,
|
|
||||||
RequestMapping: req.RequestMapping,
|
|
||||||
ResponseMapping: req.ResponseMapping,
|
|
||||||
ResponseBody: req.ResponseBody,
|
|
||||||
ResponseTokenField: req.ResponseTokenField,
|
|
||||||
IsPrivate: req.IsPrivate,
|
|
||||||
IsChatModel: req.IsChatModel,
|
|
||||||
ApiKey: req.ApiKey,
|
|
||||||
Enabled: req.Enabled,
|
|
||||||
MaxConcurrency: req.MaxConcurrency,
|
|
||||||
QueueLimit: req.QueueLimit,
|
|
||||||
TimeoutSeconds: req.TimeoutSeconds,
|
|
||||||
ExpectedSeconds: req.ExpectedSeconds,
|
|
||||||
RetryTimes: req.RetryTimes,
|
|
||||||
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
|
|
||||||
AutoCleanSeconds: req.AutoCleanSeconds,
|
|
||||||
Remark: req.Remark,
|
|
||||||
IsOwner: req.IsOwner,
|
|
||||||
OperatorName: req.OperatorName,
|
|
||||||
TokenConfig: req.TokenConfig,
|
|
||||||
ExtendMapping: req.ExtendMapping,
|
|
||||||
QueryConfig: req.QueryConfig,
|
|
||||||
})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) error {
|
|
||||||
_, err := dao.Model.Delete(ctx, &entity.AsynchModel{
|
|
||||||
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
|
|
||||||
})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetModelRes, error) {
|
|
||||||
user, err := utils.GetUserInfo(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
|
||||||
SQLBaseDO: beans.SQLBaseDO{
|
|
||||||
Id: req.ID,
|
|
||||||
Creator: user.UserName,
|
|
||||||
},
|
|
||||||
ModelName: req.ModelName,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &dto.GetModelRes{
|
|
||||||
Model: model,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) {
|
|
||||||
var models []*entity.AsynchModel
|
|
||||||
|
|
||||||
req.IsOwner = gconv.PtrInt(1)
|
|
||||||
admin, err := gateway.IsSuperAdmin(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if admin {
|
|
||||||
req.IsOwner = gconv.PtrInt(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
var user *beans.User
|
|
||||||
user, err = utils.GetUserInfo(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
req.Creator = user.UserName
|
|
||||||
|
|
||||||
models, total, err := dao.Model.GetByCreatorAndPlatform(ctx, req)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return &dto.ListModelRes{
|
|
||||||
List: models,
|
|
||||||
Total: total,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetModelTypesFromConfig 从配置文件读取模型类型
|
|
||||||
func GetModelTypesFromConfig() (res *dto.TypeItem, err error) {
|
|
||||||
// 返回副本,避免外部修改
|
|
||||||
types := make(map[int]string, len(public.ModelTypeName))
|
|
||||||
for k, v := range public.ModelTypeName {
|
|
||||||
types[k] = v
|
|
||||||
}
|
|
||||||
return &dto.TypeItem{
|
|
||||||
Type: types,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetOperatorList 获取运营商列表
|
|
||||||
func GetOperatorList() (res *dto.ListOperatorRes, err error) {
|
|
||||||
return &dto.ListOperatorRes{
|
|
||||||
List: public.OperatorList,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error {
|
|
||||||
// 校验新会话模型是否存在
|
|
||||||
newModel, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
|
|
||||||
SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if newModel == nil {
|
|
||||||
return errors.New("新会话模型不存在")
|
|
||||||
}
|
|
||||||
var user *beans.User
|
|
||||||
user, err = utils.GetUserInfo(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// 获取当前用户会话模型
|
|
||||||
currentModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
|
||||||
SQLBaseDO: beans.SQLBaseDO{
|
|
||||||
Creator: user.UserName,
|
|
||||||
},
|
|
||||||
IsChatModel: new(1),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
|
||||||
if !g.IsEmpty(currentModel) {
|
|
||||||
if currentModel.ModelType != public.ModelTypeInference {
|
|
||||||
return errors.New("当前模型为非推理模型,不能设置为会话模型")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果点击的就是当前会话模型(已经是1),取消它(设为0)
|
|
||||||
if currentModel.Id != req.Id {
|
|
||||||
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
|
||||||
SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id},
|
|
||||||
IsChatModel: gconv.PtrInt(0),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 设置当前为会话模型(设为1)
|
|
||||||
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
|
||||||
SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
|
|
||||||
IsChatModel: gconv.PtrInt(1),
|
|
||||||
})
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *modelService) GetIsChatModel(ctx context.Context) (*dto.GetIsChatModelRes, error) {
|
|
||||||
user, err := utils.GetUserInfo(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
|
||||||
SQLBaseDO: beans.SQLBaseDO{
|
|
||||||
Creator: user.UserName,
|
|
||||||
},
|
|
||||||
IsChatModel: new(1),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if model == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
return &dto.GetIsChatModelRes{
|
|
||||||
Model: model,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package service
|
package queue
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package service
|
package queue
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package service
|
package queue
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -11,9 +11,9 @@ import (
|
|||||||
// 上层每小时调用 /model/autoTune 写入运行时值;Worker/CreateTask 读取运行时值生效。
|
// 上层每小时调用 /model/autoTune 写入运行时值;Worker/CreateTask 读取运行时值生效。
|
||||||
|
|
||||||
const (
|
const (
|
||||||
runtimeMaxCKeyPrefix = "asynch:runtime:max_concurrency:" // + model_name
|
runtimeMaxCKeyPrefix = "asynch:runtime:max_concurrency:" // + model_name
|
||||||
runtimeQueueKeyPrefix = "asynch:runtime:queue_limit:" // + model_name
|
runtimeQueueKeyPrefix = "asynch:runtime:queue_limit:" // + model_name
|
||||||
runtimeTTLSeconds = 2 * 3600 // 2小时,避免一次调参失败导致立即回退
|
runtimeTTLSeconds = 2 * 3600 // 2小时,避免一次调参失败导致立即回退
|
||||||
)
|
)
|
||||||
|
|
||||||
func runtimeMaxConcurrencyKey(modelName string) string {
|
func runtimeMaxConcurrencyKey(modelName string) string {
|
||||||
@@ -80,4 +80,3 @@ func clampInt(v, minV, maxV int) int {
|
|||||||
}
|
}
|
||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package service
|
package queue
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -34,7 +34,8 @@ end
|
|||||||
return 1
|
return 1
|
||||||
`
|
`
|
||||||
|
|
||||||
func acquireSemaphore(ctx context.Context, key string, max int, ttlSeconds int64) (bool, error) {
|
// AcquireSemaphore 获取并发令牌
|
||||||
|
func AcquireSemaphore(ctx context.Context, key string, max int, ttlSeconds int64) (bool, error) {
|
||||||
if max <= 0 {
|
if max <= 0 {
|
||||||
// 不限制
|
// 不限制
|
||||||
return true, nil
|
return true, nil
|
||||||
@@ -49,8 +50,8 @@ func acquireSemaphore(ctx context.Context, key string, max int, ttlSeconds int64
|
|||||||
return gconv.Int(r) == 1, nil
|
return gconv.Int(r) == 1, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func releaseSemaphore(ctx context.Context, key string) error {
|
// ReleaseSemaphore 释放并发令牌
|
||||||
|
func ReleaseSemaphore(ctx context.Context, key string) error {
|
||||||
_, err := g.Redis().Do(ctx, "EVAL", releaseLua, 1, key)
|
_, err := g.Redis().Do(ctx, "EVAL", releaseLua, 1, key)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package service
|
package stat
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -1,9 +1,10 @@
|
|||||||
package service
|
package task
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"model-gateway/common/util"
|
"model-gateway/common/util"
|
||||||
|
"model-gateway/service/queue"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"model-gateway/dao"
|
"model-gateway/dao"
|
||||||
@@ -20,10 +21,11 @@ var Task = &taskService{}
|
|||||||
|
|
||||||
type taskService struct{}
|
type taskService struct{}
|
||||||
|
|
||||||
|
// Create 创建任务
|
||||||
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
|
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
|
||||||
startAt := time.Now()
|
startAt := time.Now()
|
||||||
// 固化 token/user 等信息
|
taskID := uuid.NewString()
|
||||||
ctx = util.AsyncCtx(ctx)
|
|
||||||
// 1) 检查模型配置
|
// 1) 检查模型配置
|
||||||
m, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
m, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||||
ModelName: req.ModelName,
|
ModelName: req.ModelName,
|
||||||
@@ -35,11 +37,10 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
|
|||||||
return nil, errors.New("模型不存在或未启用")
|
return nil, errors.New("模型不存在或未启用")
|
||||||
}
|
}
|
||||||
|
|
||||||
taskID := uuid.NewString()
|
|
||||||
// 2) 排队上限(严格控制:Redis 原子闸门)
|
// 2) 排队上限(严格控制:Redis 原子闸门)
|
||||||
limit := GetRuntimeQueueLimit(ctx, req.ModelName, m.QueueLimit)
|
limit := queue.GetRuntimeQueueLimit(ctx, req.ModelName, m.QueueLimit)
|
||||||
if limit > 0 {
|
if limit > 0 {
|
||||||
ok, err := AcquireQueueSlot(ctx, req.ModelName, taskID, limit, m.ExpectedSeconds)
|
ok, err := queue.AcquireQueueSlot(ctx, req.ModelName, taskID, limit, m.ExpectedSeconds)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -48,13 +49,12 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 将调用模型的 payload 与透传头信息一起存入 request_payload,供后台 worker 使用
|
// 3) 插入任务记录
|
||||||
storedPayload := map[string]any{
|
storedPayload := map[string]any{
|
||||||
"payload": req.RequestPayload,
|
"payload": req.RequestPayload,
|
||||||
"headers": util.ForwardHeaders(ctx),
|
"headers": util.ForwardHeaders(ctx),
|
||||||
}
|
}
|
||||||
|
_, err = dao.Task.Insert(ctx, &entity.AsynchTask{
|
||||||
t := &entity.AsynchTask{
|
|
||||||
ModelName: req.ModelName,
|
ModelName: req.ModelName,
|
||||||
TaskID: taskID,
|
TaskID: taskID,
|
||||||
State: 0,
|
State: 0,
|
||||||
@@ -64,21 +64,20 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
|
|||||||
InputRef: req.InputRef,
|
InputRef: req.InputRef,
|
||||||
RequestPayload: storedPayload,
|
RequestPayload: storedPayload,
|
||||||
EpicycleId: req.EpicycleId,
|
EpicycleId: req.EpicycleId,
|
||||||
}
|
})
|
||||||
_, err = dao.Task.Insert(ctx, t)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// 入库失败:回滚闸门占位
|
// 入库失败:回滚闸门占位
|
||||||
ReleaseQueueSlot(ctx, req.ModelName, taskID)
|
queue.ReleaseQueueSlot(ctx, req.ModelName, taskID)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3) 写操作日志(尽量不影响主流程,失败忽略)
|
// 4) 写操作日志(不影响主流程,失败忽略)
|
||||||
ip := ""
|
ip := ""
|
||||||
ua := ""
|
ua := ""
|
||||||
apiPath := "/task/createTask"
|
apiPath := "/task/createTask"
|
||||||
httpMethod := "POST"
|
httpMethod := "POST"
|
||||||
if r := g.RequestFromCtx(ctx); r != nil {
|
if r := g.RequestFromCtx(ctx); r != nil {
|
||||||
ip = r.GetClientIp()
|
ip = util.GetLocalIP()
|
||||||
ua = r.UserAgent()
|
ua = r.UserAgent()
|
||||||
apiPath = r.URL.Path
|
apiPath = r.URL.Path
|
||||||
httpMethod = r.Method
|
httpMethod = r.Method
|
||||||
@@ -101,70 +100,68 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
// 4) 创建成功后立即异步尝试执行当前任务,并仅在任务仍处于 pending(state=0) 时做定向轮询。
|
// 5) 创建成功后立即异步尝试执行当前任务,并仅在任务仍处于 pending(state=0) 时做定向轮询。
|
||||||
// 一旦任务进入 running/success/failed/downloaded,就停止轮询,避免一直空转。
|
// 一旦任务进入 running/success/failed/downloaded,就停止轮询,避免一直空转。
|
||||||
go s.pollAndRunUntilPicked(context.WithoutCancel(ctx), taskID, req)
|
go s.pollAndRunUntilPicked(util.AsyncCtx(ctx), taskID, req)
|
||||||
return &dto.CreateTaskRes{TaskID: taskID}, nil
|
return &dto.CreateTaskRes{TaskID: taskID}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// pollAndRunUntilPicked 用于 createTask 创建后的“轻量级定向轮询”:
|
// pollAndRunUntilPicked 定向轮询执行刚创建的任务
|
||||||
// - 目标:尽快把刚创建的任务拉起来执行
|
// - 目标:尽快把刚创建的任务拉起来执行
|
||||||
// - 只在任务仍为 pending(state=0) 时继续尝试抢占
|
// - 只在任务仍为 pending(state=0) 时继续尝试抢占
|
||||||
// - 一旦任务进入 running(1) / success(2) / failed(3) / downloaded(4),立即停止
|
// - 一旦任务进入 running(1) / success(2) / failed(3) / downloaded(4),立即停止
|
||||||
// - 这样不会无限轮询;runWork 仍负责处理积压队列和未处理到的任务
|
// - 不会无限轮询;runWork 仍负责处理积压队列和未处理到的任务
|
||||||
func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string, req *dto.CreateTaskReq) {
|
func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string, req *dto.CreateTaskReq) {
|
||||||
if taskID == "" {
|
interval := g.Cfg().MustGet(ctx, "asynch.worker.intervalSeconds", 5).Int()
|
||||||
return
|
pollTimeout := g.Cfg().MustGet(ctx, "asynch.worker.pollTimeoutSeconds", 300).Int()
|
||||||
}
|
pollCtx, cancel := context.WithTimeout(ctx, time.Duration(pollTimeout)*time.Second)
|
||||||
interval := g.Cfg().MustGet(ctx, "asynch.worker.intervalSeconds").Int()
|
defer cancel()
|
||||||
if interval <= 0 {
|
|
||||||
interval = 5
|
|
||||||
}
|
|
||||||
g.Log().Infof(ctx, "[task-auto-run][start] taskId=%s interval=%ds", taskID, interval)
|
|
||||||
|
|
||||||
ticker := time.NewTicker(time.Duration(interval) * time.Second)
|
ticker := time.NewTicker(time.Duration(interval) * time.Second)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
g.Log().Infof(ctx, "[任务自动执行][开始] taskId=%s 轮询间隔=%ds 超时=%ds", taskID, interval, pollTimeout)
|
||||||
tryRun := func() bool {
|
tryRun := func() bool {
|
||||||
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
|
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
|
||||||
TaskID: taskID,
|
TaskID: taskID,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
g.Log().Warningf(ctx, "[task-auto-run][stop] taskId=%s reason=query_failed err=%v", taskID, err)
|
g.Log().Warningf(ctx, "[任务自动执行][停止] taskId=%s 原因=查询失败 err=%v", taskID, err)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if t == nil {
|
if t == nil {
|
||||||
g.Log().Warningf(ctx, "[task-auto-run][stop] taskId=%s reason=task_not_found", taskID)
|
g.Log().Warningf(ctx, "[任务自动执行][停止] taskId=%s 原因=任务不存在", taskID)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
switch t.State {
|
switch t.State {
|
||||||
case 0:
|
case 0:
|
||||||
|
//RunByTaskID 尝试执行任务
|
||||||
if err = AsyncWorker.RunByTaskID(ctx, taskID, req); err != nil {
|
if err = AsyncWorker.RunByTaskID(ctx, taskID, req); err != nil {
|
||||||
g.Log().Warningf(ctx, "[task-auto-run][retry] taskId=%s state=0 err=%v", taskID, err)
|
g.Log().Warningf(ctx, "[任务自动执行][重试] taskId=%s 状态=待处理 err=%v", taskID, err)
|
||||||
} else {
|
} else {
|
||||||
g.Log().Infof(ctx, "[task-auto-run][triggered] taskId=%s state=0", taskID)
|
g.Log().Infof(ctx, "[任务自动执行][已触发] taskId=%s 状态=待处理", taskID)
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
case 1:
|
case 1:
|
||||||
g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=running", taskID)
|
g.Log().Infof(ctx, "[任务自动执行][停止] taskId=%s 原因=执行中", taskID)
|
||||||
return true
|
return true
|
||||||
case 2, 3, 4:
|
case 2, 3, 4:
|
||||||
g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=terminal state=%d", taskID, t.State)
|
g.Log().Infof(ctx, "[任务自动执行][停止] taskId=%s 原因=终态 状态=%d", taskID, t.State)
|
||||||
return true
|
return true
|
||||||
default:
|
default:
|
||||||
g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=unknown_state state=%d", taskID, t.State)
|
g.Log().Infof(ctx, "[任务自动执行][停止] taskId=%s 原因=未知状态 状态=%d", taskID, t.State)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// 立即尝试一次
|
||||||
// 先立即尝试一次
|
|
||||||
if stop := tryRun(); stop {
|
if stop := tryRun(); stop {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-pollCtx.Done():
|
||||||
g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=context_done", taskID)
|
g.Log().Infof(ctx, "[任务自动执行][停止] taskId=%s 原因=轮询超时", taskID)
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
if stop := tryRun(); stop {
|
if stop := tryRun(); stop {
|
||||||
@@ -174,6 +171,7 @@ func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetResult 获取任务结果
|
||||||
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
|
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
|
||||||
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
|
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
|
||||||
TaskID: taskID,
|
TaskID: taskID,
|
||||||
@@ -244,6 +242,7 @@ func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (r
|
|||||||
return &dto.GetTaskBatchRes{List: items}, nil
|
return &dto.GetTaskBatchRes{List: items}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// List 获取任务列表
|
||||||
func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
|
func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
|
||||||
pageNum, pageSize := 1, 10
|
pageNum, pageSize := 1, 10
|
||||||
if req != nil {
|
if req != nil {
|
||||||
@@ -1,12 +1,17 @@
|
|||||||
package service
|
package task
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"model-gateway/common/util"
|
"model-gateway/common/util"
|
||||||
"model-gateway/model/dto"
|
"model-gateway/model/dto"
|
||||||
"model-gateway/service/gateway"
|
"model-gateway/service/gateway"
|
||||||
|
"model-gateway/service/queue"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -56,7 +61,7 @@ func (w *asyncWorker) RunOnce(ctx context.Context, req *dto.RunWorkReq) (res *dt
|
|||||||
if e != nil {
|
if e != nil {
|
||||||
task.ErrorMsg = fmt.Sprintf("worker panic: %v", e)
|
task.ErrorMsg = fmt.Sprintf("worker panic: %v", e)
|
||||||
_ = dao.Task.UpdateFailedGlobal(ctx, task)
|
_ = dao.Task.UpdateFailedGlobal(ctx, task)
|
||||||
ReleaseQueueSlot(ctx, task.ModelName, task.TaskID)
|
queue.ReleaseQueueSlot(ctx, task.ModelName, task.TaskID)
|
||||||
}
|
}
|
||||||
done <- struct{}{}
|
done <- struct{}{}
|
||||||
})
|
})
|
||||||
@@ -100,8 +105,8 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *
|
|||||||
|
|
||||||
// 2) 分布式并发控制
|
// 2) 分布式并发控制
|
||||||
semKey := fmt.Sprintf("asynch:sem:%s", t.ModelName)
|
semKey := fmt.Sprintf("asynch:sem:%s", t.ModelName)
|
||||||
maxC := GetRuntimeMaxConcurrency(ctx, t.ModelName, model.MaxConcurrency)
|
maxC := queue.GetRuntimeMaxConcurrency(ctx, t.ModelName, model.MaxConcurrency)
|
||||||
acquired, err := acquireSemaphore(ctx, semKey, maxC, 3600)
|
acquired, err := queue.AcquireSemaphore(ctx, semKey, maxC, 3600)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.failTask(ctx, t, err.Error())
|
w.failTask(ctx, t, err.Error())
|
||||||
return
|
return
|
||||||
@@ -111,7 +116,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *
|
|||||||
_ = w.rollbackToPending(ctx, t.Id)
|
_ = w.rollbackToPending(ctx, t.Id)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() { _ = releaseSemaphore(ctx, semKey) }()
|
defer func() { _ = queue.ReleaseSemaphore(ctx, semKey) }()
|
||||||
|
|
||||||
// 3) request_payload 校验
|
// 3) request_payload 校验
|
||||||
if payload == nil {
|
if payload == nil {
|
||||||
@@ -146,31 +151,32 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 6) 解析校验(可重试,失败重新调模型)
|
// 6) 解析校验(可重试,失败重新调模型)
|
||||||
if req.BuildType == 1 {
|
//if req.BuildType == 1 {
|
||||||
for attempt := 0; attempt <= maxRetry; attempt++ {
|
// for attempt := 0; attempt <= maxRetry; attempt++ {
|
||||||
if attempt > 0 {
|
// if attempt > 0 {
|
||||||
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID)
|
// g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID)
|
||||||
}
|
// }
|
||||||
err = util.ValidatePromptResult(textResult, model.RequestMapping)
|
// // 6.1) 校验数据
|
||||||
if err == nil {
|
// err = util.ValidatePromptResult(textResult, model)
|
||||||
break
|
// if err == nil {
|
||||||
}
|
// break
|
||||||
g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v",
|
// }
|
||||||
t.TaskID, attempt, maxRetry, err)
|
// g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v",
|
||||||
if attempt == maxRetry {
|
// t.TaskID, attempt, maxRetry, err)
|
||||||
w.failTask(ctx, t, fmt.Sprintf("JSON解析重试耗尽: %v", err))
|
// if attempt == maxRetry {
|
||||||
return
|
// w.failTask(ctx, t, fmt.Sprintf("JSON解析重试耗尽: %v", err))
|
||||||
}
|
// return
|
||||||
// 重新调模型
|
// }
|
||||||
newResult, modelErr := w.callModel(ctx, t, model, payload)
|
// // 6.2) 重新调模型
|
||||||
if modelErr != nil {
|
// newResult, modelErr := w.callModel(ctx, t, model, payload)
|
||||||
g.Log().Warningf(ctx, "[执行任务][重试] 重新调模型失败 taskId=%s attempt=%d/%d err=%v",
|
// if modelErr != nil {
|
||||||
t.TaskID, attempt, maxRetry, modelErr)
|
// g.Log().Warningf(ctx, "[执行任务][重试] 重新调模型失败 taskId=%s attempt=%d/%d err=%v",
|
||||||
continue
|
// t.TaskID, attempt, maxRetry, modelErr)
|
||||||
}
|
// continue
|
||||||
textResult = newResult
|
// }
|
||||||
}
|
// textResult = newResult
|
||||||
}
|
// }
|
||||||
|
//}
|
||||||
|
|
||||||
// 7) 成功回调
|
// 7) 成功回调
|
||||||
t.State = 2
|
t.State = 2
|
||||||
@@ -185,7 +191,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||||
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||||||
if req.EpicycleId != 0 {
|
if req.EpicycleId != 0 {
|
||||||
go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), t, req.EpicycleId)
|
go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), t, req.EpicycleId)
|
||||||
@@ -198,29 +204,29 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *
|
|||||||
|
|
||||||
// 返回: ossURL(成功时有值), fileType, textResult(失败时是错误信息), retryable(是否可重试)
|
// 返回: ossURL(成功时有值), fileType, textResult(失败时是错误信息), retryable(是否可重试)
|
||||||
// callModel 调用模型 + 检测文件类型 + 保存临时文件
|
// callModel 调用模型 + 检测文件类型 + 保存临时文件
|
||||||
func (w *asyncWorker) callModel(ctx context.Context, t *entity.AsynchTask, m *entity.AsynchModel, payload map[string]any) (map[string]any, error) {
|
func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, payload map[string]any) (map[string]any, error) {
|
||||||
var data []byte
|
var data []byte
|
||||||
var contentType, ext, textResult string
|
var contentType, ext, textResult string
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
if t.Phase == 1 && strings.TrimSpace(t.TmpFile) != "" {
|
if task.Phase == 1 && strings.TrimSpace(task.TmpFile) != "" {
|
||||||
data, err = os.ReadFile(t.TmpFile)
|
data, err = os.ReadFile(task.TmpFile)
|
||||||
if err != nil || len(data) == 0 {
|
if err != nil || len(data) == 0 {
|
||||||
data = nil
|
data = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if data == nil {
|
if data == nil {
|
||||||
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(t.TenantId), t.Creator, t.ModelName)
|
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName)
|
||||||
data, err = InvokeModel(ctx, m, payload, t.ModelKey)
|
data, err = InvokeModel(ctx, model, payload, task.ModelKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
tmpPath, tmpErr := saveTmpResult(t.TaskID, data, ext)
|
tmpPath, tmpErr := saveTmpResult(task.TaskID, data, ext)
|
||||||
if tmpErr == nil && tmpPath != "" {
|
if tmpErr == nil && tmpPath != "" {
|
||||||
t.TmpFile = tmpPath
|
task.TmpFile = tmpPath
|
||||||
t.Phase = 1
|
task.Phase = 1
|
||||||
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, t.Id, tmpPath)
|
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -228,10 +234,138 @@ func (w *asyncWorker) callModel(ctx context.Context, t *entity.AsynchTask, m *en
|
|||||||
if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") {
|
if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") {
|
||||||
textResult = string(data)
|
textResult = string(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
return gjson.New(textResult).Map(), nil
|
return gjson.New(textResult).Map(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InvokeModel 调用模型服务,返回二进制结果
|
||||||
|
// modelKey 用于覆盖/补充模型配置 head_msg(例如每次请求携带不同的 X-API-Key)
|
||||||
|
func InvokeModel(ctx context.Context, model *entity.AsynchModel, payload map[string]any, modelKey string) ([]byte, error) {
|
||||||
|
// 1)请求参数映射:将标准 payload 按模型配置的 requestMapping 转为模型需要的格式
|
||||||
|
//mappedPayload := util.ReverseMap(model.RequestMapping, payload)
|
||||||
|
|
||||||
|
// 2)构建请求 URL 和超时
|
||||||
|
baseURL := strings.TrimRight(model.BaseURL, "/")
|
||||||
|
timeout := time.Duration(model.TimeoutSeconds) * time.Second
|
||||||
|
client := &http.Client{Timeout: timeout}
|
||||||
|
method := strings.ToUpper(strings.TrimSpace(model.HttpMethod))
|
||||||
|
|
||||||
|
// 3)构建 HTTP 请求
|
||||||
|
var req *http.Request
|
||||||
|
switch method {
|
||||||
|
case http.MethodGet:
|
||||||
|
q, err := util.PayloadToQuery(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(q) > 0 {
|
||||||
|
if strings.Contains(baseURL, "?") {
|
||||||
|
baseURL = baseURL + "&" + q.Encode()
|
||||||
|
} else {
|
||||||
|
baseURL = baseURL + "?" + q.Encode()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
req, err = http.NewRequestWithContext(ctx, http.MethodGet, baseURL, nil)
|
||||||
|
default:
|
||||||
|
bodyBytes, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req, err = http.NewRequestWithContext(ctx, http.MethodPost, baseURL, bytes.NewReader(bodyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4)注入请求头:先模型静态配置,再动态 modelKey(后者可覆盖前者)
|
||||||
|
for hk, hv := range util.ParseHeadMsgHeaders(model.HeadMsg) {
|
||||||
|
req.Header.Set(hk, hv)
|
||||||
|
}
|
||||||
|
for hk, hv := range util.ParseHeadMsgHeaders(modelKey) {
|
||||||
|
req.Header.Set(hk, hv)
|
||||||
|
}
|
||||||
|
if method != http.MethodGet {
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5)发送请求
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// 6)读取响应体
|
||||||
|
b, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 7)检查 HTTP 状态码
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
msg := string(b)
|
||||||
|
return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 8)响应参数映射
|
||||||
|
mappedResponse, err := util.MapResponsePayload(model.ResponseMapping, b)
|
||||||
|
if err != nil {
|
||||||
|
g.Log().Warningf(ctx, "响应参数映射失败: %v,返回原始数据", err)
|
||||||
|
return b, nil
|
||||||
|
}
|
||||||
|
return mappedResponse, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// // InvokeModel 调用模型服务,返回二进制结果
|
||||||
|
//
|
||||||
|
// func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelKey string) ([]byte, error) {
|
||||||
|
// if m == nil || m.BaseURL == "" {
|
||||||
|
// return nil, fmt.Errorf("模型配置不完整")
|
||||||
|
// }
|
||||||
|
// // 请求参数映射
|
||||||
|
// mappedPayload, err := mapRequestPayload(m.RequestMapping, payload)
|
||||||
|
// if err != nil {
|
||||||
|
// return nil, fmt.Errorf("请求参数映射失败: %w", err)
|
||||||
|
// }
|
||||||
|
// // 合并请求头
|
||||||
|
// headers := util.ForwardHeaders(ctx)
|
||||||
|
// for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) {
|
||||||
|
// headers[hk] = hv
|
||||||
|
// }
|
||||||
|
// for hk, hv := range parseHeadMsgHeaders(modelKey) {
|
||||||
|
// headers[hk] = hv
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// // 设置超时
|
||||||
|
// timeout := time.Duration(m.TimeoutSeconds) * time.Second
|
||||||
|
// if timeout <= 0 {
|
||||||
|
// timeout = 600 * time.Second
|
||||||
|
// }
|
||||||
|
// ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
|
// defer cancel()
|
||||||
|
//
|
||||||
|
// invokeUrl := strings.TrimRight(m.BaseURL, "/")
|
||||||
|
// method := strings.ToUpper(strings.TrimSpace(m.HttpMethod))
|
||||||
|
// if method == "" {
|
||||||
|
// method = http.MethodPost
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// var respBytes []byte
|
||||||
|
//
|
||||||
|
// switch method {
|
||||||
|
// case http.MethodGet:
|
||||||
|
// err = commonHttp.Get(ctx, invokeUrl, headers, &respBytes, mappedPayload)
|
||||||
|
// default:
|
||||||
|
// err = commonHttp.Post(ctx, invokeUrl, headers, &respBytes, mappedPayload)
|
||||||
|
// }
|
||||||
|
// if err != nil {
|
||||||
|
// return nil, err
|
||||||
|
// }
|
||||||
|
// // 响应参数映射
|
||||||
|
// mappedResponse, err := mapResponsePayload(m.ResponseMapping, respBytes)
|
||||||
|
// if err != nil {
|
||||||
|
// g.Log().Warningf(ctx, "响应参数映射失败: %v,返回原始数据", err)
|
||||||
|
// return respBytes, nil
|
||||||
|
// }
|
||||||
|
// return mappedResponse, nil
|
||||||
|
// }
|
||||||
|
|
||||||
// uploadOSS 从临时文件上传 OSS
|
// uploadOSS 从临时文件上传 OSS
|
||||||
func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.AsynchTask) (*gateway.UploadFileResponse, error) {
|
func (w *asyncWorker) uploadOSS(ctx context.Context, t *entity.AsynchTask) (*gateway.UploadFileResponse, error) {
|
||||||
data, err := os.ReadFile(t.TmpFile)
|
data, err := os.ReadFile(t.TmpFile)
|
||||||
@@ -247,7 +381,7 @@ func (w *asyncWorker) failTask(ctx context.Context, t *entity.AsynchTask, errMsg
|
|||||||
t.State = 3
|
t.State = 3
|
||||||
t.ErrorMsg = errMsg
|
t.ErrorMsg = errMsg
|
||||||
_ = dao.Task.UpdateFailedGlobal(ctx, t)
|
_ = dao.Task.UpdateFailedGlobal(ctx, t)
|
||||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
queue.ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||||
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||||||
}
|
}
|
||||||
|
|
||||||
Reference in New Issue
Block a user