gatway
This commit is contained in:
@@ -2,87 +2,86 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
"encoding/json"
|
||||
|
||||
"model-asynch/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/http"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
// triggerSuccessCallback 任务成功后的回调钩子:
|
||||
// - 使用 GET 请求
|
||||
// - 回调地址为 callbackUrl + "/" + bizName
|
||||
// - query 参数:task_id/state/oss_file/file_type
|
||||
// 注意:回调失败不影响任务主流程,只记录日志。
|
||||
func triggerSuccessCallback(ctx context.Context, t *entity.AsynchTask) {
|
||||
if t == nil {
|
||||
return
|
||||
// triggerCallback 任务成功后的回调:
|
||||
// - JSON body 参数:task_id/state/oss_file/file_type/text(可选)
|
||||
func triggerCallback(ctx context.Context, t *entity.AsynchTask) {
|
||||
callbackURL := t.BizName + t.CallbackURL
|
||||
headers := forwardHeaders(ctx)
|
||||
var req struct{}
|
||||
payload := map[string]interface{}{
|
||||
"task_id": t.TaskID,
|
||||
"state": t.State,
|
||||
"oss_file": t.OssFile,
|
||||
"file_type": t.FileType,
|
||||
"text": t.TextResult,
|
||||
"error_msg": t.ErrorMsg,
|
||||
}
|
||||
callbackURL := strings.TrimSpace(t.CallbackURL)
|
||||
bizName := strings.TrimSpace(t.BizName)
|
||||
if callbackURL == "" || bizName == "" {
|
||||
return
|
||||
}
|
||||
|
||||
u, err := url.Parse(callbackURL)
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[callback] invalid callbackUrl=%s err=%v", callbackURL, err)
|
||||
return
|
||||
}
|
||||
// 必须是可发起 HTTP 请求的绝对地址
|
||||
if u.Scheme == "" || u.Host == "" {
|
||||
g.Log().Warningf(ctx, "[callback] callbackUrl must be absolute http(s) url, got=%s", callbackURL)
|
||||
g.Log().Warningf(ctx, "[回调] JSON序列化失败 taskId=%s 错误=%v", t.TaskID, err)
|
||||
return
|
||||
}
|
||||
g.Log().Infof(ctx, "[回调] 开始发送 taskId=%s 回调地址=%s 请求头数量=%d 消息体大小=%d字节",
|
||||
t.TaskID, callbackURL, len(headers), len(jsonData))
|
||||
|
||||
// path 末尾拼接 bizName
|
||||
bizSeg := url.PathEscape(bizName)
|
||||
if strings.HasSuffix(u.Path, "/") || u.Path == "" {
|
||||
u.Path = strings.TrimRight(u.Path, "/") + "/" + bizSeg
|
||||
} else {
|
||||
u.Path = u.Path + "/" + bizSeg
|
||||
}
|
||||
|
||||
q := u.Query()
|
||||
q.Set("task_id", t.TaskID)
|
||||
q.Set("state", fmt.Sprintf("%d", t.State))
|
||||
q.Set("oss_file", t.OssFile)
|
||||
q.Set("file_type", t.FileType)
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
|
||||
err = http.Post(ctx, callbackURL, headers, &req, jsonData)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[callback] build request failed url=%s err=%v", u.String(), err)
|
||||
g.Log().Warningf(ctx, "[回调] 发送失败 taskId=%s 回调地址=%s 错误=%v", t.TaskID, callbackURL, err)
|
||||
return
|
||||
}
|
||||
// 透传必要头部(如 Authorization / X-User-Info)
|
||||
for k, v := range forwardHeaders(ctx) {
|
||||
if v != "" {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[callback] request failed url=%s err=%v", u.String(), err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
msg := string(b)
|
||||
if len(msg) > 2000 {
|
||||
msg = msg[:2000]
|
||||
}
|
||||
g.Log().Warningf(ctx, "[callback] non-2xx url=%s code=%d body=%s", u.String(), resp.StatusCode, msg)
|
||||
return
|
||||
}
|
||||
g.Log().Infof(ctx, "[callback] success url=%s code=%d", u.String(), resp.StatusCode)
|
||||
g.Log().Infof(ctx, "[回调] 发送成功 taskId=%s 回调地址=%s 消息体大小=%d字节", t.TaskID, callbackURL, len(jsonData))
|
||||
}
|
||||
|
||||
// triggerPromptsCallback 任务成功后的提示词回调
|
||||
// - JSON body 参数:epicycleId(轮次id)/textResult(模型回答消息)
|
||||
func triggerPromptsCallback(ctx context.Context, t *entity.AsynchTask, epicycleId int64) {
|
||||
callbackURL := "prompts-core/session/sessionCallback"
|
||||
headers := forwardHeaders(ctx)
|
||||
var req struct{}
|
||||
payload := map[string]interface{}{
|
||||
"epicycleId": epicycleId,
|
||||
"text": t.TextResult,
|
||||
}
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[提示词回调] JSON序列化失败 epicycleId=%d 错误=%v", epicycleId, err)
|
||||
return
|
||||
}
|
||||
g.Log().Infof(ctx, "[提示词回调] 开始发送 epicycleId=%d 回调地址=%s 请求头数量=%d 消息体大小=%d字节",
|
||||
t.EpicycleId, callbackURL, len(headers), len(jsonData))
|
||||
|
||||
err = http.Post(ctx, callbackURL, headers, &req, jsonData)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[提示词回调] 发送失败 epicycleId=%d 回调地址=%s 错误=%v", t.EpicycleId, callbackURL, err)
|
||||
return
|
||||
}
|
||||
g.Log().Infof(ctx, "[提示词回调] 发送成功 epicycleId=%d 回调地址=%s 消息体大小=%d字节", t.EpicycleId, callbackURL, len(jsonData))
|
||||
}
|
||||
|
||||
// IsSuperAdmin 调用admin-go服务检查是否是超级管理员
|
||||
func IsSuperAdmin(ctx context.Context) (res bool, err error) {
|
||||
headers := forwardHeaders(ctx)
|
||||
var r = make(map[string]bool)
|
||||
if err = http.Get(ctx, "admin-go/api/v1/system/user/checkIsSuperAdmin", headers, &r); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return r["isSuperAdmin"], err
|
||||
}
|
||||
|
||||
// IsAdmin 调用admin-go服务检查是否是管理员
|
||||
func IsAdmin(ctx context.Context) (res bool, err error) {
|
||||
headers := forwardHeaders(ctx)
|
||||
var r = make(map[string]bool)
|
||||
if err = http.Get(ctx, "admin-go/api/v1/system/user/checkIsSuperAdmin", headers, &r); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return r["isSuperAdmin"], err
|
||||
}
|
||||
|
||||
@@ -11,6 +11,10 @@ func DetectFileType(data []byte) (contentType string, ext string) {
|
||||
return "application/octet-stream", ""
|
||||
}
|
||||
ct := http.DetectContentType(data)
|
||||
// http.DetectContentType 可能带 charset 等参数:text/plain; charset=utf-8
|
||||
if idx := strings.Index(ct, ";"); idx > 0 {
|
||||
ct = strings.TrimSpace(ct[:idx])
|
||||
}
|
||||
switch ct {
|
||||
case "audio/mpeg":
|
||||
return ct, ".mp3"
|
||||
@@ -24,12 +28,20 @@ func DetectFileType(data []byte) (contentType string, ext string) {
|
||||
return ct, ".jpg"
|
||||
case "application/pdf":
|
||||
return ct, ".pdf"
|
||||
case "text/plain":
|
||||
return ct, ".txt"
|
||||
case "application/json":
|
||||
return ct, ".json"
|
||||
default:
|
||||
// 兜底:尝试从 ct 截取 subtype 作为后缀(例如 application/json)
|
||||
if parts := strings.Split(ct, "/"); len(parts) == 2 {
|
||||
return ct, "." + parts[1]
|
||||
sub := parts[1]
|
||||
// 避免出现 "plain; charset=utf-8" 之类的后缀
|
||||
if idx := strings.Index(sub, ";"); idx > 0 {
|
||||
sub = strings.TrimSpace(sub[:idx])
|
||||
}
|
||||
return ct, "." + sub
|
||||
}
|
||||
return ct, ""
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -51,4 +51,3 @@ func forwardHeaders(ctx context.Context) map[string]string {
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,11 @@ import (
|
||||
"time"
|
||||
|
||||
"model-asynch/model/entity"
|
||||
|
||||
"github.com/gogf/gf/v2/container/gvar"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// parseHeadMsgHeaders 支持多个 header 绑定,逗号分隔:
|
||||
@@ -100,11 +105,14 @@ func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelK
|
||||
if m == nil || m.BaseURL == "" {
|
||||
return nil, fmt.Errorf("模型配置不完整")
|
||||
}
|
||||
url := strings.TrimRight(m.BaseURL, "/") + "/" + strings.TrimLeft(m.Route, "/")
|
||||
if strings.TrimSpace(m.Route) == "" {
|
||||
url = strings.TrimRight(m.BaseURL, "/")
|
||||
|
||||
// ============ 新增:请求参数映射 ============
|
||||
mappedPayload, err := mapRequestPayload(m.RequestMapping, payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求参数映射失败: %w", err)
|
||||
}
|
||||
|
||||
url := strings.TrimRight(m.BaseURL, "/")
|
||||
timeout := time.Duration(m.TimeoutSeconds) * time.Second
|
||||
if timeout <= 0 {
|
||||
timeout = 60 * time.Second
|
||||
@@ -118,11 +126,10 @@ func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelK
|
||||
|
||||
var (
|
||||
req *http.Request
|
||||
err error
|
||||
)
|
||||
switch method {
|
||||
case http.MethodGet:
|
||||
q, err := payloadToQuery(payload)
|
||||
q, err := payloadToQuery(mappedPayload) // 使用映射后的payload
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -135,7 +142,7 @@ func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelK
|
||||
}
|
||||
req, err = http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
default:
|
||||
bodyBytes, err := json.Marshal(payload)
|
||||
bodyBytes, err := json.Marshal(mappedPayload) // 使用映射后的payload
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -145,20 +152,16 @@ func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelK
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 先注入模型配置 head_msg(静态头部)
|
||||
// 先注入模型配置 head_msg(静态头部,适合公共模型固定 API Key)
|
||||
for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) {
|
||||
req.Header.Set(hk, hv)
|
||||
}
|
||||
// 透传必要头部(如 Authorization / X-User-Info)
|
||||
for k, v := range forwardHeaders(ctx) {
|
||||
if v != "" {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
}
|
||||
// 最后注入动态 modelKey(覆盖/补充静态 head_msg)
|
||||
|
||||
// 最后注入动态 modelKey(允许覆盖/补充静态 head_msg),适合按请求动态传密钥。
|
||||
for hk, hv := range parseHeadMsgHeaders(modelKey) {
|
||||
req.Header.Set(hk, hv)
|
||||
}
|
||||
|
||||
if method != http.MethodGet {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
@@ -174,12 +177,241 @@ func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelK
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
// 尽量把错误体带回去,方便排查
|
||||
msg := string(b)
|
||||
if len(msg) > 2000 {
|
||||
msg = msg[:2000]
|
||||
}
|
||||
return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
|
||||
}
|
||||
return b, nil
|
||||
|
||||
// ============ 新增:响应参数映射 ============
|
||||
mappedResponse, err := mapResponsePayload(m.ResponseMapping, b)
|
||||
if err != nil {
|
||||
// 响应映射失败不阻塞,返回原始数据
|
||||
g.Log().Warningf(ctx, "响应参数映射失败: %v,返回原始数据", err)
|
||||
return b, nil
|
||||
}
|
||||
// =========================================
|
||||
|
||||
return mappedResponse, nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 映射相关函数
|
||||
// ============================================
|
||||
|
||||
// mapRequestPayload 将标准请求映射为模型特定格式
|
||||
func mapRequestPayload(mappingAny any, payload any) (any, error) {
|
||||
// 1. 解析请求映射配置(值是any类型,支持bool、number等)
|
||||
mapping, err := parseRequestMapping(mappingAny)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 如果没有映射配置,直接返回原始payload
|
||||
if len(mapping) == 0 {
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
// 2. 将payload转为map
|
||||
var payloadMap map[string]any
|
||||
switch v := payload.(type) {
|
||||
case map[string]any:
|
||||
payloadMap = v
|
||||
case []map[string]any:
|
||||
// 如果传进来的是纯messages数组,包装成标准格式
|
||||
payloadMap = map[string]any{
|
||||
"messages": v,
|
||||
}
|
||||
default:
|
||||
// 通过JSON转换
|
||||
jsonBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化payload失败: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(jsonBytes, &payloadMap); err != nil {
|
||||
return nil, fmt.Errorf("反序列化payload失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 用数据库固定参数覆盖/补充
|
||||
for key, value := range mapping {
|
||||
if existingValue, exists := payloadMap[key]; !exists || isEmptyValue(existingValue) {
|
||||
payloadMap[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
return payloadMap, nil
|
||||
}
|
||||
|
||||
// mapResponsePayload 将模型响应映射为标准格式
|
||||
func mapResponsePayload(mappingAny any, responseBytes []byte) ([]byte, error) {
|
||||
mapping, err := parseResponseMapping(mappingAny)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(mapping) == 0 {
|
||||
return responseBytes, nil
|
||||
}
|
||||
|
||||
responseStr := string(responseBytes)
|
||||
resultStr := `{}`
|
||||
|
||||
for standardField, modelPath := range mapping {
|
||||
value := gjson.Get(responseStr, modelPath)
|
||||
if !value.Exists() {
|
||||
continue
|
||||
}
|
||||
|
||||
resultStr, err = sjson.SetRaw(resultStr, standardField, value.Raw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("提取字段 %s <- %s 失败: %w", standardField, modelPath, err)
|
||||
}
|
||||
}
|
||||
|
||||
return []byte(resultStr), nil
|
||||
}
|
||||
|
||||
func parseRequestMapping(mappingAny any) (map[string]any, error) {
|
||||
if mappingAny == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
result := make(map[string]any)
|
||||
|
||||
switch v := mappingAny.(type) {
|
||||
case *gvar.Var:
|
||||
if v == nil || v.IsNil() || v.IsEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
// 尝试转成 map
|
||||
if m := v.Map(); m != nil {
|
||||
for k, val := range m {
|
||||
result[k] = val
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
// 尝试转成 string
|
||||
if s := v.String(); s != "" && s != "{}" && s != "null" {
|
||||
if err := json.Unmarshal([]byte(s), &result); err != nil {
|
||||
return nil, fmt.Errorf("解析请求映射字符串失败: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
return nil, nil
|
||||
// =======================================================
|
||||
|
||||
case map[string]interface{}:
|
||||
result = v
|
||||
|
||||
case string:
|
||||
if v == "" || v == "{}" || v == "null" {
|
||||
return nil, nil
|
||||
}
|
||||
if err := json.Unmarshal([]byte(v), &result); err != nil {
|
||||
return nil, fmt.Errorf("解析请求映射字符串失败: %w", err)
|
||||
}
|
||||
|
||||
case []byte:
|
||||
if len(v) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if err := json.Unmarshal(v, &result); err != nil {
|
||||
return nil, fmt.Errorf("解析请求映射字节失败: %w", err)
|
||||
}
|
||||
|
||||
default:
|
||||
jsonBytes, err := json.Marshal(mappingAny)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化映射配置失败: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(jsonBytes, &result); err != nil {
|
||||
return nil, fmt.Errorf("解析映射配置失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// parseResponseMapping 解析响应映射配置
|
||||
// 返回值类型为 map[string]string,值都是JSON路径字符串
|
||||
func parseResponseMapping(mappingAny any) (map[string]string, error) {
|
||||
if mappingAny == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
mapping := make(map[string]string)
|
||||
|
||||
switch v := mappingAny.(type) {
|
||||
case *gvar.Var:
|
||||
if v == nil || v.IsNil() || v.IsEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
if m := v.Map(); m != nil {
|
||||
for k, val := range m {
|
||||
if strVal, ok := val.(string); ok {
|
||||
mapping[k] = strVal
|
||||
}
|
||||
}
|
||||
return mapping, nil
|
||||
}
|
||||
if s := v.String(); s != "" && s != "{}" && s != "null" {
|
||||
if err := json.Unmarshal([]byte(s), &mapping); err != nil {
|
||||
return nil, fmt.Errorf("解析响应映射字符串失败: %w", err)
|
||||
}
|
||||
return mapping, nil
|
||||
}
|
||||
return nil, nil
|
||||
case string:
|
||||
if v == "" || v == "{}" || v == "null" {
|
||||
return nil, nil
|
||||
}
|
||||
if err := json.Unmarshal([]byte(v), &mapping); err != nil {
|
||||
return nil, fmt.Errorf("解析响应映射字符串失败: %w", err)
|
||||
}
|
||||
|
||||
case map[string]interface{}:
|
||||
// 数据库JSONB直接返回的map
|
||||
for k, val := range v {
|
||||
if strVal, ok := val.(string); ok {
|
||||
mapping[k] = strVal
|
||||
}
|
||||
}
|
||||
|
||||
case []byte:
|
||||
if len(v) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if err := json.Unmarshal(v, &mapping); err != nil {
|
||||
return nil, fmt.Errorf("解析响应映射字节失败: %w", err)
|
||||
}
|
||||
|
||||
default:
|
||||
jsonBytes, err := json.Marshal(mappingAny)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化响应映射配置失败: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(jsonBytes, &mapping); err != nil {
|
||||
return nil, fmt.Errorf("解析响应映射配置失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return mapping, nil
|
||||
}
|
||||
|
||||
// isEmptyValue 判断值是否为空
|
||||
func isEmptyValue(v any) bool {
|
||||
if v == nil {
|
||||
return true
|
||||
}
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
return val == ""
|
||||
case []any:
|
||||
return len(val) == 0
|
||||
case map[string]any:
|
||||
return len(val) == 0
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,10 +3,15 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sort"
|
||||
|
||||
"model-asynch/dao"
|
||||
"model-asynch/model/dto"
|
||||
"model-asynch/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
var Model = &modelService{}
|
||||
@@ -15,40 +20,28 @@ type modelService struct{}
|
||||
|
||||
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
|
||||
m := &entity.AsynchModel{
|
||||
ModelName: req.ModelName,
|
||||
ModelsType: normalizeModelsType(req.ModelsType),
|
||||
BaseURL: req.BaseURL,
|
||||
Route: req.Route,
|
||||
HttpMethod: req.HttpMethod,
|
||||
HeadMsg: req.HeadMsg,
|
||||
Form: req.Form,
|
||||
Enabled: req.Enabled,
|
||||
MaxConcurrency: req.MaxConcurrency,
|
||||
QueueLimit: req.QueueLimit,
|
||||
TimeoutSeconds: req.TimeoutSeconds,
|
||||
ExpectedSeconds: req.ExpectedSeconds,
|
||||
RetryTimes: req.RetryTimes,
|
||||
RetryQueueMaxSecs: req.RetryQueueMaxSeconds,
|
||||
AutoCleanSeconds: req.AutoCleanSeconds,
|
||||
Remark: req.Remark,
|
||||
}
|
||||
if m.HttpMethod == "" {
|
||||
m.HttpMethod = "POST"
|
||||
}
|
||||
if m.Enabled == 0 {
|
||||
m.Enabled = 1
|
||||
}
|
||||
if m.MaxConcurrency <= 0 {
|
||||
m.MaxConcurrency = 10
|
||||
}
|
||||
if m.QueueLimit <= 0 {
|
||||
m.QueueLimit = 1000
|
||||
}
|
||||
if m.TimeoutSeconds <= 0 {
|
||||
m.TimeoutSeconds = 60
|
||||
}
|
||||
if m.AutoCleanSeconds <= 0 {
|
||||
m.AutoCleanSeconds = 86400
|
||||
ModelName: req.ModelName,
|
||||
ModelsType: req.ModelsType,
|
||||
BaseURL: req.BaseURL,
|
||||
HttpMethod: req.HttpMethod,
|
||||
HeadMsg: req.HeadMsg,
|
||||
IsPrivate: req.IsPrivate,
|
||||
Enabled: req.Enabled,
|
||||
IsChatModel: req.IsChatModel,
|
||||
ApiKey: req.ApiKey,
|
||||
Form: req.Form,
|
||||
RequestMapping: req.RequestMapping,
|
||||
ResponseMapping: req.ResponseMapping,
|
||||
ResponseBody: req.ResponseBody,
|
||||
TokenMapping: req.TokenMapping,
|
||||
MaxConcurrency: req.MaxConcurrency,
|
||||
QueueLimit: req.QueueLimit,
|
||||
TimeoutSeconds: req.TimeoutSeconds,
|
||||
ExpectedSeconds: req.ExpectedSeconds,
|
||||
RetryTimes: req.RetryTimes,
|
||||
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
|
||||
AutoCleanSeconds: req.AutoCleanSeconds,
|
||||
Remark: req.Remark,
|
||||
}
|
||||
id, err := dao.Model.Insert(ctx, m)
|
||||
if err != nil {
|
||||
@@ -58,68 +51,223 @@ func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res
|
||||
}
|
||||
|
||||
func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) error {
|
||||
data := map[string]any{}
|
||||
if req.BaseURL != "" {
|
||||
data[entity.AsynchModelCol.BaseURL] = req.BaseURL
|
||||
//根据当前 isChatModel 来判断是否更新模型
|
||||
if req.IsChatModel == 1 {
|
||||
user, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
//判断当前用户是否有会话模型
|
||||
model, err := dao.Model.GetByIsChatModel(ctx, user.UserName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if model != nil {
|
||||
return errors.New("用户已存在会话模型,不能创建新的会话模型")
|
||||
}
|
||||
_, err = dao.Model.Update(ctx, req)
|
||||
return err
|
||||
}
|
||||
if req.Route != "" {
|
||||
data[entity.AsynchModelCol.Route] = req.Route
|
||||
}
|
||||
if req.HttpMethod != nil && *req.HttpMethod != "" {
|
||||
data[entity.AsynchModelCol.HttpMethod] = *req.HttpMethod
|
||||
}
|
||||
if req.HeadMsg != nil {
|
||||
data[entity.AsynchModelCol.HeadMsg] = *req.HeadMsg
|
||||
}
|
||||
if req.Form != nil {
|
||||
data[entity.AsynchModelCol.FormJSON] = req.Form
|
||||
}
|
||||
if req.ModelsType != nil {
|
||||
data[entity.AsynchModelCol.ModelsType] = normalizeModelsType(*req.ModelsType)
|
||||
}
|
||||
if req.Enabled != nil {
|
||||
data[entity.AsynchModelCol.Enabled] = *req.Enabled
|
||||
}
|
||||
if req.MaxConcurrency != nil {
|
||||
data[entity.AsynchModelCol.MaxConcurrency] = *req.MaxConcurrency
|
||||
}
|
||||
if req.QueueLimit != nil {
|
||||
data[entity.AsynchModelCol.QueueLimit] = *req.QueueLimit
|
||||
}
|
||||
if req.TimeoutSeconds != nil {
|
||||
data[entity.AsynchModelCol.TimeoutSeconds] = *req.TimeoutSeconds
|
||||
}
|
||||
if req.ExpectedSeconds != nil {
|
||||
data[entity.AsynchModelCol.ExpectedSeconds] = *req.ExpectedSeconds
|
||||
}
|
||||
if req.RetryTimes != nil {
|
||||
data[entity.AsynchModelCol.RetryTimes] = *req.RetryTimes
|
||||
}
|
||||
if req.RetryQueueMaxSeconds != nil {
|
||||
data[entity.AsynchModelCol.RetryQueueMaxSecs] = *req.RetryQueueMaxSeconds
|
||||
}
|
||||
if req.AutoCleanSeconds != nil {
|
||||
data[entity.AsynchModelCol.AutoCleanSeconds] = *req.AutoCleanSeconds
|
||||
}
|
||||
if req.Remark != nil {
|
||||
data[entity.AsynchModelCol.Remark] = *req.Remark
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return errors.New("无可更新字段")
|
||||
}
|
||||
_, err := dao.Model.UpdateByID(ctx, req.ID, data)
|
||||
_, err := dao.Model.Update(ctx, req)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *modelService) Delete(ctx context.Context, id int64) error {
|
||||
func (s *modelService) Delete(ctx context.Context, id string) error {
|
||||
_, err := dao.Model.DeleteByID(ctx, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *modelService) Get(ctx context.Context, id int64) (*entity.AsynchModel, error) {
|
||||
return dao.Model.GetByID(ctx, id)
|
||||
model, err := dao.Model.Get(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
model.Form = ParseJSONField(model.Form)
|
||||
model.RequestMapping = ParseJSONField(model.RequestMapping)
|
||||
model.ResponseMapping = ParseJSONField(model.ResponseMapping)
|
||||
model.ResponseBody = ParseJSONField(model.ResponseBody)
|
||||
return model, nil
|
||||
}
|
||||
|
||||
func (s *modelService) List(ctx context.Context, pageNum, pageSize int, modelNameLike string) (list []*entity.AsynchModel, total int64, err error) {
|
||||
return dao.Model.List(ctx, pageNum, pageSize, modelNameLike)
|
||||
func (s *modelService) List(ctx context.Context, pageNum, pageSize int, modelNameLike string, modelType int) (list []*entity.AsynchModel, total int64, err error) {
|
||||
isSuperAdmin, err := IsSuperAdmin(ctx)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
user, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
var models []*entity.AsynchModel
|
||||
var count int64
|
||||
|
||||
if isSuperAdmin {
|
||||
models, count, err = dao.Model.List(ctx, pageNum, pageSize, modelNameLike, modelType)
|
||||
} else {
|
||||
models, count, err = s.getModelsWithDedup(ctx, user.UserName, pageNum, pageSize, modelNameLike, modelType)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 处理列表中每条记录的 JSONB 字段
|
||||
for _, m := range models {
|
||||
m.Form = ParseJSONField(m.Form)
|
||||
m.RequestMapping = ParseJSONField(m.RequestMapping)
|
||||
m.ResponseMapping = ParseJSONField(m.ResponseMapping)
|
||||
m.ResponseBody = ParseJSONField(m.ResponseBody)
|
||||
}
|
||||
return models, count, nil
|
||||
}
|
||||
|
||||
// getModelsWithDedup 获取普通用户的模型列表并去重
|
||||
func (s *modelService) getModelsWithDedup(ctx context.Context, creator string, pageNum, pageSize int, modelNameLike string, modelType int) (list []*entity.AsynchModel, total int64, err error) {
|
||||
// 1. 查全量数据(不分页,便于去重)
|
||||
allModels, err := dao.Model.GetByCreatorAndPlatform(ctx, creator, modelNameLike, modelType)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 2. 按 modelName 去重,保留当前用户的
|
||||
modelMap := make(map[string]*entity.AsynchModel)
|
||||
for _, m := range allModels {
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
name := m.ModelName
|
||||
|
||||
_, ok := modelMap[name]
|
||||
if !ok {
|
||||
// 没有冲突,直接放进去
|
||||
modelMap[name] = m
|
||||
} else {
|
||||
// 有冲突,保留当前用户创建的
|
||||
if m.Creator == creator {
|
||||
modelMap[name] = m
|
||||
}
|
||||
// 如果现有的就是当前用户的,不做任何替换
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 转回切片并排序
|
||||
deduped := make([]*entity.AsynchModel, 0, len(modelMap))
|
||||
for _, m := range modelMap {
|
||||
deduped = append(deduped, m)
|
||||
}
|
||||
sort.Slice(deduped, func(i, j int) bool {
|
||||
return deduped[i].CreatedAt.After(deduped[j].CreatedAt)
|
||||
})
|
||||
|
||||
// 4. 手动分页
|
||||
total = int64(len(deduped))
|
||||
if pageNum > 0 && pageSize > 0 {
|
||||
start := (pageNum - 1) * pageSize
|
||||
if start >= len(deduped) {
|
||||
return []*entity.AsynchModel{}, total, nil
|
||||
}
|
||||
end := start + pageSize
|
||||
if end > len(deduped) {
|
||||
end = len(deduped)
|
||||
}
|
||||
deduped = deduped[start:end]
|
||||
}
|
||||
return deduped, total, nil
|
||||
}
|
||||
|
||||
// GetModelTypesFromConfig 从配置文件读取模型类型
|
||||
func GetModelTypesFromConfig(ctx context.Context) map[int]string {
|
||||
typeMap := make(map[int]string)
|
||||
|
||||
// 读取配置
|
||||
configMap := g.Cfg().MustGet(ctx, "modelType.types").Map()
|
||||
for k, v := range configMap {
|
||||
typeID := gconv.Int(k)
|
||||
typeName := gconv.String(v)
|
||||
if typeID > 0 && typeName != "" {
|
||||
typeMap[typeID] = typeName
|
||||
}
|
||||
}
|
||||
// 如果配置为空,使用默认值
|
||||
if len(typeMap) == 0 {
|
||||
typeMap = map[int]string{
|
||||
1: "推理模型",
|
||||
2: "图片模型",
|
||||
3: "音频模型",
|
||||
4: "向量化模型",
|
||||
5: "全模态模型",
|
||||
}
|
||||
}
|
||||
return typeMap
|
||||
}
|
||||
|
||||
func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error {
|
||||
user, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 校验新会话模型是否存在
|
||||
newModel, err := dao.Model.Get(ctx, req.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if newModel == nil {
|
||||
return errors.New("新会话模型不存在")
|
||||
}
|
||||
|
||||
// 获取当前用户会话模型
|
||||
currentModel, err := dao.Model.GetByIsChatModel(ctx, user.UserName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if currentModel.ModelsType != 1 {
|
||||
return errors.New("当前模型为非推理模型,不能设置为会话模型")
|
||||
}
|
||||
|
||||
// 如果点击的就是当前会话模型(已经是1),取消它(设为0)
|
||||
if currentModel != nil && currentModel.Id == req.Id {
|
||||
_, err = dao.Model.UpdateByID(ctx, &dto.UpdateModelReq{
|
||||
ID: req.Id,
|
||||
IsChatModel: 0,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果之前有会话模型,取消它(设为0)
|
||||
if currentModel != nil {
|
||||
_, err = dao.Model.UpdateByID(ctx, &dto.UpdateModelReq{
|
||||
ID: currentModel.Id,
|
||||
IsChatModel: 0,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 设置当前为会话模型(设为1)
|
||||
_, err = dao.Model.UpdateByID(ctx, &dto.UpdateModelReq{
|
||||
ID: req.Id,
|
||||
IsChatModel: 1,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *modelService) GetIsChatModel(ctx context.Context) (*entity.AsynchModel, error) {
|
||||
user, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
model, err := dao.Model.GetByIsChatModel(ctx, user.UserName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if model == nil {
|
||||
return nil, nil
|
||||
}
|
||||
model.Form = ParseJSONField(model.Form)
|
||||
model.RequestMapping = ParseJSONField(model.RequestMapping)
|
||||
model.ResponseMapping = ParseJSONField(model.ResponseMapping)
|
||||
model.ResponseBody = ParseJSONField(model.ResponseBody)
|
||||
return model, nil
|
||||
}
|
||||
|
||||
@@ -1,217 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"model-asynch/dao"
|
||||
"model-asynch/model/dto"
|
||||
"model-asynch/model/entity"
|
||||
|
||||
"github.com/gogf/gf/v2/container/gvar"
|
||||
)
|
||||
|
||||
type modelTypeService struct{}
|
||||
|
||||
var ModelType = &modelTypeService{}
|
||||
|
||||
func normalizeFormValue(v any) any {
|
||||
// 目标:对外永远返回 JSON 数组/对象,而不是字符串。
|
||||
if v == nil {
|
||||
return []any{}
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case string:
|
||||
s := strings.TrimSpace(t)
|
||||
if s == "" {
|
||||
return []any{}
|
||||
}
|
||||
return normalizeFormValueFromJSONString(s)
|
||||
case []byte:
|
||||
if len(t) == 0 {
|
||||
return []any{}
|
||||
}
|
||||
return normalizeFormValueFromJSONBytes(t)
|
||||
case *gvar.Var:
|
||||
// goframe 常见的 DB 返回类型
|
||||
if t == nil {
|
||||
return []any{}
|
||||
}
|
||||
b := t.Bytes()
|
||||
if len(b) > 0 {
|
||||
return normalizeFormValueFromJSONBytes(b)
|
||||
}
|
||||
s := strings.TrimSpace(t.String())
|
||||
if s == "" {
|
||||
return []any{}
|
||||
}
|
||||
return normalizeFormValueFromJSONString(s)
|
||||
default:
|
||||
// 尝试兼容其他“像 JSON 的值类型”(例如实现了 Bytes/String 的包装类型)
|
||||
if vb, ok := v.(interface{ Bytes() []byte }); ok {
|
||||
if b := vb.Bytes(); len(b) > 0 {
|
||||
return normalizeFormValueFromJSONBytes(b)
|
||||
}
|
||||
}
|
||||
if vs, ok := v.(interface{ String() string }); ok {
|
||||
if s := strings.TrimSpace(vs.String()); s != "" {
|
||||
return normalizeFormValueFromJSONString(s)
|
||||
}
|
||||
}
|
||||
// 已经是 []any / map[string]any 等结构
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
// 兼容“JSONB 里存了 JSON 字符串”的历史数据:
|
||||
// 例如 form_json = '"[]"' 或 '"[{...}]"'(外层是字符串,内层才是数组/对象)
|
||||
func normalizeFormValueFromJSONString(s string) any {
|
||||
var out any
|
||||
if err := json.Unmarshal([]byte(s), &out); err != nil || out == nil {
|
||||
return []any{}
|
||||
}
|
||||
// 如果解出来还是 string,且看起来是 JSON,再解一层
|
||||
if inner, ok := out.(string); ok {
|
||||
inner = strings.TrimSpace(inner)
|
||||
if inner == "" {
|
||||
return []any{}
|
||||
}
|
||||
if strings.HasPrefix(inner, "[") || strings.HasPrefix(inner, "{") {
|
||||
var out2 any
|
||||
if err := json.Unmarshal([]byte(inner), &out2); err == nil && out2 != nil {
|
||||
return out2
|
||||
}
|
||||
}
|
||||
return []any{}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeFormValueFromJSONBytes(b []byte) any {
|
||||
var out any
|
||||
if err := json.Unmarshal(b, &out); err != nil || out == nil {
|
||||
return []any{}
|
||||
}
|
||||
// bytes 解出来也可能是 string(同上)
|
||||
if inner, ok := out.(string); ok {
|
||||
return normalizeFormValueFromJSONString(inner)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *modelTypeService) Create(ctx context.Context, req *dto.CreateModelTypeReq) (res *dto.CreateModelTypeRes, err error) {
|
||||
t := &entity.AsynchModelType{
|
||||
TypeID: req.TypeID,
|
||||
TypeName: req.TypeName,
|
||||
Remark: req.Remark,
|
||||
}
|
||||
id, err := dao.ModelType.Insert(ctx, t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.CreateModelTypeRes{ID: id}, nil
|
||||
}
|
||||
|
||||
func (s *modelTypeService) Update(ctx context.Context, req *dto.UpdateModelTypeReq) error {
|
||||
data := map[string]any{}
|
||||
if req.TypeID != nil {
|
||||
data[entity.AsynchModelTypeCol.TypeID] = *req.TypeID
|
||||
}
|
||||
if req.TypeName != nil {
|
||||
data[entity.AsynchModelTypeCol.TypeName] = *req.TypeName
|
||||
}
|
||||
if req.Remark != nil {
|
||||
data[entity.AsynchModelTypeCol.Remark] = *req.Remark
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return errors.New("无可更新字段")
|
||||
}
|
||||
_, err := dao.ModelType.UpdateByID(ctx, req.ID, data)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *modelTypeService) Delete(ctx context.Context, id int64) error {
|
||||
_, err := dao.ModelType.DeleteByID(ctx, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *modelTypeService) Get(ctx context.Context, id int64) (*entity.AsynchModelType, error) {
|
||||
return dao.ModelType.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *modelTypeService) List(ctx context.Context, pageNum, pageSize int, typeNameLike string) (list []*entity.AsynchModelType, total int64, err error) {
|
||||
return dao.ModelType.List(ctx, pageNum, pageSize, typeNameLike)
|
||||
}
|
||||
|
||||
// ListWithModels 按类型分组返回模型(返回数组,便于前端直接渲染)
|
||||
func (s *modelTypeService) ListWithModels(ctx context.Context, req *dto.ListModelTypeWithModelsReq) (res []dto.ModelTypeWithModelsItem, err error) {
|
||||
types, _, err := dao.ModelType.List(ctx, 1, 1000, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 过滤类型(按 typeId / typeName 模糊)
|
||||
filterTypeID := 0
|
||||
filterTypeName := ""
|
||||
if req != nil {
|
||||
filterTypeID = req.TypeID
|
||||
filterTypeName = strings.TrimSpace(req.Type)
|
||||
}
|
||||
typeIDs := make([]int, 0, len(types))
|
||||
typeNameMap := make(map[int]string, len(types))
|
||||
for _, t := range types {
|
||||
if t == nil {
|
||||
continue
|
||||
}
|
||||
if filterTypeID > 0 && t.TypeID != filterTypeID {
|
||||
continue
|
||||
}
|
||||
if filterTypeName != "" && !strings.Contains(t.TypeName, filterTypeName) {
|
||||
continue
|
||||
}
|
||||
typeIDs = append(typeIDs, t.TypeID)
|
||||
typeNameMap[t.TypeID] = t.TypeName
|
||||
}
|
||||
models, err := dao.Model.ListAll(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
itemsMap := map[int][]dto.ModelTypeModelItem{}
|
||||
for _, m := range models {
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
form := normalizeFormValue(m.Form)
|
||||
// 一个模型可能支持多个类型:models_type="1,2,3"
|
||||
for _, tid := range parseModelsTypeIDs(m.ModelsType) {
|
||||
// 若请求过滤了类型,则只输出该类型
|
||||
if filterTypeID > 0 && tid != filterTypeID {
|
||||
continue
|
||||
}
|
||||
if filterTypeName != "" {
|
||||
if _, ok := typeNameMap[tid]; !ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
itemsMap[tid] = append(itemsMap[tid], dto.ModelTypeModelItem{
|
||||
ID: m.Id,
|
||||
Name: m.ModelName,
|
||||
Form: form,
|
||||
})
|
||||
}
|
||||
}
|
||||
out := make([]dto.ModelTypeWithModelsItem, 0, len(typeIDs))
|
||||
for _, tid := range typeIDs {
|
||||
items := itemsMap[tid]
|
||||
if items == nil {
|
||||
items = make([]dto.ModelTypeModelItem, 0)
|
||||
}
|
||||
out = append(out, dto.ModelTypeWithModelsItem{
|
||||
TypeID: tid,
|
||||
Type: typeNameMap[tid],
|
||||
Items: items,
|
||||
})
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
@@ -1,52 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// normalizeModelsType 将 "1, 2,2,3" 归一化为 "1,2,3"
|
||||
// - 去空格
|
||||
// - 去重
|
||||
// - 升序排序
|
||||
func normalizeModelsType(v string) string {
|
||||
ids := parseModelsTypeIDs(v)
|
||||
if len(ids) == 0 {
|
||||
return ""
|
||||
}
|
||||
parts := make([]string, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
parts = append(parts, strconv.Itoa(id))
|
||||
}
|
||||
return strings.Join(parts, ",")
|
||||
}
|
||||
|
||||
// parseModelsTypeIDs 解析 models_type 字段(支持 "1,2,3"),返回去重后的 int 列表(升序)。
|
||||
func parseModelsTypeIDs(v string) []int {
|
||||
v = strings.TrimSpace(v)
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
raw := strings.Split(v, ",")
|
||||
seen := map[int]struct{}{}
|
||||
out := make([]int, 0, len(raw))
|
||||
for _, s := range raw {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" || s == "0" {
|
||||
continue
|
||||
}
|
||||
id, err := strconv.Atoi(s)
|
||||
if err != nil || id <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
out = append(out, id)
|
||||
}
|
||||
sort.Ints(out)
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -13,12 +13,12 @@ var Stat = &statService{}
|
||||
|
||||
func (s *statService) List(ctx context.Context, req *dto.ListModelStatReq) (res *dto.ListModelStatRes, err error) {
|
||||
pageNum, pageSize := 1, 10
|
||||
if req != nil && req.Page != nil {
|
||||
if req.Page.PageNum > 0 {
|
||||
pageNum = int(req.Page.PageNum)
|
||||
if req != nil {
|
||||
if req.PageNum > 0 {
|
||||
pageNum = req.PageNum
|
||||
}
|
||||
if req.Page.PageSize > 0 {
|
||||
pageSize = int(req.Page.PageSize)
|
||||
if req.PageSize > 0 {
|
||||
pageSize = req.PageSize
|
||||
}
|
||||
}
|
||||
startDay, endDay := "", ""
|
||||
@@ -37,4 +37,3 @@ func (s *statService) List(ctx context.Context, req *dto.ListModelStatReq) (res
|
||||
}
|
||||
return &dto.ListModelStatRes{List: list, Total: total}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -62,7 +62,6 @@ func (s *ossStorage) UploadByTask(ctx context.Context, _ *entity.AsynchTask, dat
|
||||
if err := commonHttp.Post(ctx, fullURL, headers, &resp, body.Bytes()); err != nil {
|
||||
return "", err
|
||||
}
|
||||
fmt.Println("打印结果 resp:", resp)
|
||||
g.Log().Infof(ctx, "[OSS] upload success url=%s size=%d format=%s", resp.FileURL, resp.FileSize, resp.FileFormat)
|
||||
return resp.FileURL, nil
|
||||
}
|
||||
|
||||
@@ -58,9 +58,10 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
|
||||
State: 0,
|
||||
BizName: req.BizName,
|
||||
CallbackURL: req.CallbackUrl,
|
||||
ModelKey: req.ModelKey,
|
||||
ModelKey: m.ApiKey,
|
||||
InputRef: req.InputRef,
|
||||
RequestPayload: storedPayload,
|
||||
EpicycleId: req.EpicycleId,
|
||||
}
|
||||
_, err = dao.Task.Insert(ctx, t)
|
||||
if err != nil {
|
||||
@@ -80,7 +81,7 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
|
||||
apiPath = r.URL.Path
|
||||
httpMethod = r.Method
|
||||
}
|
||||
_, _ = dao.OpLog.Insert(ctx, &entity.AsynchOpLog{
|
||||
_, _ = dao.OpLog.Insert(ctx, &entity.LogsModelOp{
|
||||
IP: ip,
|
||||
UserAgent: ua,
|
||||
APIPath: apiPath,
|
||||
@@ -97,9 +98,80 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
|
||||
"taskId": taskID,
|
||||
},
|
||||
})
|
||||
|
||||
// 4) 创建成功后立即异步尝试执行当前任务,并仅在任务仍处于 pending(state=0) 时做定向轮询。
|
||||
// 一旦任务进入 running/success/failed/downloaded,就停止轮询,避免一直空转。
|
||||
go s.pollAndRunUntilPicked(context.WithoutCancel(ctx), taskID, req.EpicycleId)
|
||||
|
||||
return &dto.CreateTaskRes{TaskID: taskID}, nil
|
||||
}
|
||||
|
||||
// pollAndRunUntilPicked 用于 createTask 创建后的“轻量级定向轮询”:
|
||||
// - 目标:尽快把刚创建的任务拉起来执行
|
||||
// - 只在任务仍为 pending(state=0) 时继续尝试抢占
|
||||
// - 一旦任务进入 running(1) / success(2) / failed(3) / downloaded(4),立即停止
|
||||
// - 这样不会无限轮询;runWork 仍负责处理积压队列和未处理到的任务
|
||||
func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string, epicycleId int64) {
|
||||
if taskID == "" {
|
||||
return
|
||||
}
|
||||
interval := g.Cfg().MustGet(ctx, "asynch.worker.intervalSeconds").Int()
|
||||
if interval <= 0 {
|
||||
interval = 5
|
||||
}
|
||||
g.Log().Infof(ctx, "[task-auto-run][start] taskId=%s interval=%ds", taskID, interval)
|
||||
|
||||
ticker := time.NewTicker(time.Duration(interval) * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
tryRun := func() bool {
|
||||
t, err := dao.Task.GetByTaskID(ctx, taskID)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[task-auto-run][stop] taskId=%s reason=query_failed err=%v", taskID, err)
|
||||
return true
|
||||
}
|
||||
if t == nil {
|
||||
g.Log().Warningf(ctx, "[task-auto-run][stop] taskId=%s reason=task_not_found", taskID)
|
||||
return true
|
||||
}
|
||||
switch t.State {
|
||||
case 0:
|
||||
if err := AsyncWorker.RunByTaskID(ctx, taskID, epicycleId); err != nil {
|
||||
g.Log().Warningf(ctx, "[task-auto-run][retry] taskId=%s state=0 err=%v", taskID, err)
|
||||
} else {
|
||||
g.Log().Infof(ctx, "[task-auto-run][triggered] taskId=%s state=0", taskID)
|
||||
}
|
||||
return false
|
||||
case 1:
|
||||
g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=running", taskID)
|
||||
return true
|
||||
case 2, 3, 4:
|
||||
g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=terminal state=%d", taskID, t.State)
|
||||
return true
|
||||
default:
|
||||
g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=unknown_state state=%d", taskID, t.State)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// 先立即尝试一次
|
||||
if stop := tryRun(); stop {
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=context_done", taskID)
|
||||
return
|
||||
case <-ticker.C:
|
||||
if stop := tryRun(); stop {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
|
||||
t, err := dao.Task.GetByTaskID(ctx, taskID)
|
||||
if err != nil {
|
||||
@@ -168,12 +240,12 @@ func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (r
|
||||
|
||||
func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
|
||||
pageNum, pageSize := 1, 10
|
||||
if req != nil && req.Page != nil {
|
||||
if req.Page.PageNum > 0 {
|
||||
pageNum = int(req.Page.PageNum)
|
||||
if req != nil {
|
||||
if req.PageNum > 0 {
|
||||
pageNum = req.PageNum
|
||||
}
|
||||
if req.Page.PageSize > 0 {
|
||||
pageSize = int(req.Page.PageSize)
|
||||
if req.PageSize > 0 {
|
||||
pageSize = req.PageSize
|
||||
}
|
||||
}
|
||||
modelName := ""
|
||||
|
||||
113
service/utils.go
Normal file
113
service/utils.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/gogf/gf/v2/container/gvar"
|
||||
)
|
||||
|
||||
func normalizeFormValue(v any) any {
|
||||
// 目标:对外永远返回 JSON 数组/对象,而不是字符串。
|
||||
if v == nil {
|
||||
return []any{}
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case string:
|
||||
s := strings.TrimSpace(t)
|
||||
if s == "" {
|
||||
return []any{}
|
||||
}
|
||||
return normalizeFormValueFromJSONString(s)
|
||||
case []byte:
|
||||
if len(t) == 0 {
|
||||
return []any{}
|
||||
}
|
||||
return normalizeFormValueFromJSONBytes(t)
|
||||
case *gvar.Var:
|
||||
// goframe 常见的 DB 返回类型
|
||||
if t == nil {
|
||||
return []any{}
|
||||
}
|
||||
b := t.Bytes()
|
||||
if len(b) > 0 {
|
||||
return normalizeFormValueFromJSONBytes(b)
|
||||
}
|
||||
s := strings.TrimSpace(t.String())
|
||||
if s == "" {
|
||||
return []any{}
|
||||
}
|
||||
return normalizeFormValueFromJSONString(s)
|
||||
default:
|
||||
// 尝试兼容其他“像 JSON 的值类型”(例如实现了 Bytes/String 的包装类型)
|
||||
if vb, ok := v.(interface{ Bytes() []byte }); ok {
|
||||
if b := vb.Bytes(); len(b) > 0 {
|
||||
return normalizeFormValueFromJSONBytes(b)
|
||||
}
|
||||
}
|
||||
if vs, ok := v.(interface{ String() string }); ok {
|
||||
if s := strings.TrimSpace(vs.String()); s != "" {
|
||||
return normalizeFormValueFromJSONString(s)
|
||||
}
|
||||
}
|
||||
// 已经是 []any / map[string]any 等结构
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
// 兼容“JSONB 里存了 JSON 字符串”的历史数据:
|
||||
// 例如 form_json = '"[]"' 或 '"[{...}]"'(外层是字符串,内层才是数组/对象)
|
||||
func normalizeFormValueFromJSONString(s string) any {
|
||||
var out any
|
||||
if err := json.Unmarshal([]byte(s), &out); err != nil || out == nil {
|
||||
return []any{}
|
||||
}
|
||||
// 如果解出来还是 string,且看起来是 JSON,再解一层
|
||||
if inner, ok := out.(string); ok {
|
||||
inner = strings.TrimSpace(inner)
|
||||
if inner == "" {
|
||||
return []any{}
|
||||
}
|
||||
if strings.HasPrefix(inner, "[") || strings.HasPrefix(inner, "{") {
|
||||
var out2 any
|
||||
if err := json.Unmarshal([]byte(inner), &out2); err == nil && out2 != nil {
|
||||
return out2
|
||||
}
|
||||
}
|
||||
return []any{}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeFormValueFromJSONBytes(b []byte) any {
|
||||
var out any
|
||||
if err := json.Unmarshal(b, &out); err != nil || out == nil {
|
||||
return []any{}
|
||||
}
|
||||
// bytes 解出来也可能是 string(同上)
|
||||
if inner, ok := out.(string); ok {
|
||||
return normalizeFormValueFromJSONString(inner)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func ParseJSONField(field any) any {
|
||||
var v *gvar.Var
|
||||
switch val := field.(type) {
|
||||
case *gvar.Var:
|
||||
v = val
|
||||
default:
|
||||
return field
|
||||
}
|
||||
|
||||
if v == nil || v.IsNil() || v.IsEmpty() {
|
||||
return nil
|
||||
}
|
||||
|
||||
str := v.String()
|
||||
var result any
|
||||
if json.Unmarshal([]byte(str), &result) == nil {
|
||||
return result
|
||||
}
|
||||
return str
|
||||
}
|
||||
@@ -5,12 +5,14 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"model-asynch/dao"
|
||||
"model-asynch/model/entity"
|
||||
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/os/grpool"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
var AsyncWorker = &asyncWorker{}
|
||||
@@ -43,7 +45,7 @@ func (w *asyncWorker) RunOnce(ctx context.Context, batchSize, goroutines int) (c
|
||||
for _, t := range tasks {
|
||||
task := t
|
||||
_ = pool.AddWithRecover(ctx, func(ctx context.Context) {
|
||||
w.handleOne(ctx, task)
|
||||
w.handleOne(ctx, task, 0)
|
||||
done <- struct{}{}
|
||||
}, func(ctx context.Context, e error) {
|
||||
if e != nil {
|
||||
@@ -59,8 +61,23 @@ func (w *asyncWorker) RunOnce(ctx context.Context, batchSize, goroutines int) (c
|
||||
return claimed, nil
|
||||
}
|
||||
|
||||
func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask) {
|
||||
// 从任务入库的 request_payload 里恢复 payload + headers,给 OSS 上传透传鉴权用
|
||||
// RunByTaskID 创建任务后立即异步尝试执行当前任务:
|
||||
// - 只定向抢占当前 taskId 对应的 pending 任务
|
||||
// - 若任务已被其它 worker 抢走/已不在 pending,则直接返回
|
||||
func (w *asyncWorker) RunByTaskID(ctx context.Context, taskID string, epicycleId int64) error {
|
||||
task, err := dao.Task.ClaimPendingByTaskIDGlobal(ctx, taskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if task == nil {
|
||||
return nil
|
||||
}
|
||||
w.handleOne(ctx, task, epicycleId)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicycleId int64) {
|
||||
// 从任务入库的 request_payload 里恢复 payload + headers
|
||||
payload, headers := parseStoredPayload(t.RequestPayload)
|
||||
if len(headers) > 0 {
|
||||
ctx = setTaskHeadersToCtx(ctx, headers)
|
||||
@@ -71,26 +88,42 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask) {
|
||||
if err != nil {
|
||||
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
|
||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
// ============ 失败回调 ============
|
||||
t.State = 3
|
||||
t.ErrorMsg = err.Error()
|
||||
go triggerCallback(context.WithoutCancel(ctx), t)
|
||||
// ================================
|
||||
return
|
||||
}
|
||||
if m == nil || m.Enabled != 1 {
|
||||
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, "模型不存在或未启用")
|
||||
errMsg := "模型不存在或未启用"
|
||||
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, errMsg)
|
||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
// ============ 失败回调 ============
|
||||
t.State = 3
|
||||
t.ErrorMsg = errMsg
|
||||
go triggerCallback(context.WithoutCancel(ctx), t)
|
||||
// ================================
|
||||
return
|
||||
}
|
||||
|
||||
// 2) 分布式并发限制(按 model_name 全局维度)
|
||||
// 2) 分布式并发限制
|
||||
semKey := fmt.Sprintf("asynch:sem:%s", t.ModelName)
|
||||
leaseSeconds := int64(3600) // 兜底1小时
|
||||
leaseSeconds := int64(3600)
|
||||
maxC := GetRuntimeMaxConcurrency(ctx, t.ModelName, m.MaxConcurrency)
|
||||
acquired, err := acquireSemaphore(ctx, semKey, maxC, leaseSeconds)
|
||||
if err != nil {
|
||||
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
|
||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
// ============ 失败回调 ============
|
||||
t.State = 3
|
||||
t.ErrorMsg = err.Error()
|
||||
go triggerCallback(context.WithoutCancel(ctx), t)
|
||||
// ================================
|
||||
return
|
||||
}
|
||||
if !acquired {
|
||||
// 并发满了:放回排队(重新置回 state=0),下一轮再抢占
|
||||
// 并发满了:放回排队,不回调(不是失败)
|
||||
_ = w.rollbackToPending(ctx, t.Id)
|
||||
return
|
||||
}
|
||||
@@ -109,30 +142,40 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask) {
|
||||
data []byte
|
||||
contentType string
|
||||
ext string
|
||||
textResult string
|
||||
)
|
||||
|
||||
// phase=1 表示模型已成功但 OSS 上传失败:优先从临时文件加载,避免重复跑模型
|
||||
// phase=1 表示模型已成功但 OSS 上传失败:优先从临时文件加载
|
||||
if t.Phase == 1 && strings.TrimSpace(t.TmpFile) != "" {
|
||||
data, err = loadTmpResult(t.TmpFile)
|
||||
if err == nil && len(data) > 0 {
|
||||
contentType, ext = DetectFileType(data)
|
||||
} else {
|
||||
// 临时文件不可用:回退重新调用模型
|
||||
data = nil
|
||||
}
|
||||
}
|
||||
if data == nil {
|
||||
// 统计:仅在真正请求模型时 +1(OSS 重试不计入)
|
||||
// 统计
|
||||
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(t.TenantId), t.Creator, t.ModelName)
|
||||
|
||||
// 核心调用
|
||||
data, err = InvokeModel(ctx, m, payload, t.ModelKey)
|
||||
if err != nil {
|
||||
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
|
||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
// ============ 失败回调 ============
|
||||
t.State = 3
|
||||
t.ErrorMsg = err.Error()
|
||||
go triggerCallback(context.WithoutCancel(ctx), t)
|
||||
// ================================
|
||||
return
|
||||
}
|
||||
contentType, ext = DetectFileType(data)
|
||||
// 将模型输出写入临时文件,后续若 OSS 失败可只重试 OSS
|
||||
if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") {
|
||||
textResult = string(data)
|
||||
if len(textResult) > 20000 {
|
||||
textResult = textResult[:20000]
|
||||
}
|
||||
}
|
||||
tmpPath, err := saveTmpResult(t.TaskID, data, ext)
|
||||
if err == nil && tmpPath != "" {
|
||||
t.TmpFile = tmpPath
|
||||
@@ -147,26 +190,46 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask) {
|
||||
// OSS 阶段失败:保留临时文件,下一轮仅重试 OSS
|
||||
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, t.Id, err.Error())
|
||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
// ============ OSS失败不回调(还会重试) ============
|
||||
// 注意:OSS失败保留临时文件,下次重试,所以这里不触发最终回调
|
||||
// 如果已经重试多次还没成功,需要在任务超时或超过最大重试次数时才回调失败
|
||||
return
|
||||
}
|
||||
|
||||
// 5) 更新任务状态成功
|
||||
// 注意:expire_at 的计算改为“已下载(state=4)后开始计时”,因此成功(state=2)不写 expire_at。
|
||||
fileType := strings.TrimPrefix(ext, ".")
|
||||
if fileType == "" {
|
||||
fileType = contentType
|
||||
}
|
||||
if err := dao.Task.UpdateSuccessGlobal(ctx, t.Id, ossURL, fileType, int64(len(data)), nil); err != nil {
|
||||
if err := dao.Task.UpdateSuccessGlobal(
|
||||
ctx,
|
||||
t.Id,
|
||||
ossURL,
|
||||
fileType,
|
||||
textResult,
|
||||
int64(len(data)),
|
||||
nil,
|
||||
GetExpendTokens(m.TokenMapping, textResult),
|
||||
); err != nil {
|
||||
g.Log().Errorf(ctx, "[worker] update success failed: %v", err)
|
||||
return
|
||||
}
|
||||
// 成功/失败均不再占用 queue_limit(state=0/1 才占用)
|
||||
|
||||
// 成功/失败均不再占用 queue_limit
|
||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
// 6) 成功回调(不影响主流程)
|
||||
|
||||
// 6) 成功回调
|
||||
t.State = 2
|
||||
t.OssFile = ossURL
|
||||
t.FileType = fileType
|
||||
go triggerSuccessCallback(context.WithoutCancel(ctx), t)
|
||||
t.TextResult = textResult
|
||||
g.Log().Infof(ctx, "[CALLBACK][DISPATCH] taskId=%s bizName=%s callbackUrl=%s", t.TaskID, t.BizName, t.CallbackURL)
|
||||
go triggerCallback(context.WithoutCancel(ctx), t)
|
||||
// ============ 如果有 epicycleId,也触发业务回调 ============
|
||||
if epicycleId != 0 {
|
||||
go triggerPromptsCallback(context.WithoutCancel(ctx), t, epicycleId)
|
||||
}
|
||||
|
||||
// 成功后清理临时文件
|
||||
deleteTmpResult(t.TmpFile)
|
||||
}
|
||||
@@ -174,3 +237,13 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask) {
|
||||
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
|
||||
return dao.Task.RollbackToPendingGlobal(ctx, id)
|
||||
}
|
||||
|
||||
// GetExpendTokens 根据映射路径从 textResult 中提取消耗 token 值
|
||||
func GetExpendTokens(tokenMapping string, textResult string) int {
|
||||
value := gjson.Get(textResult, tokenMapping)
|
||||
if value.Exists() {
|
||||
return int(value.Int())
|
||||
} else {
|
||||
return len(textResult)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user