This commit is contained in:
2026-05-12 13:45:08 +08:00
parent e81df5ce5a
commit 37d3461983
38 changed files with 1721 additions and 1113 deletions

View File

@@ -2,87 +2,86 @@ package service
import (
"context"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"encoding/json"
"model-asynch/model/entity"
"gitea.com/red-future/common/http"
"github.com/gogf/gf/v2/frame/g"
)
// triggerSuccessCallback 任务成功后的回调钩子
// - 使用 GET 请求
// - 回调地址为 callbackUrl + "/" + bizName
// - query 参数task_id/state/oss_file/file_type
// 注意:回调失败不影响任务主流程,只记录日志。
func triggerSuccessCallback(ctx context.Context, t *entity.AsynchTask) {
if t == nil {
return
// triggerCallback 任务成功后的回调:
// - JSON body 参数task_id/state/oss_file/file_type/text可选
func triggerCallback(ctx context.Context, t *entity.AsynchTask) {
callbackURL := t.BizName + t.CallbackURL
headers := forwardHeaders(ctx)
var req struct{}
payload := map[string]interface{}{
"task_id": t.TaskID,
"state": t.State,
"oss_file": t.OssFile,
"file_type": t.FileType,
"text": t.TextResult,
"error_msg": t.ErrorMsg,
}
callbackURL := strings.TrimSpace(t.CallbackURL)
bizName := strings.TrimSpace(t.BizName)
if callbackURL == "" || bizName == "" {
return
}
u, err := url.Parse(callbackURL)
jsonData, err := json.Marshal(payload)
if err != nil {
g.Log().Warningf(ctx, "[callback] invalid callbackUrl=%s err=%v", callbackURL, err)
return
}
// 必须是可发起 HTTP 请求的绝对地址
if u.Scheme == "" || u.Host == "" {
g.Log().Warningf(ctx, "[callback] callbackUrl must be absolute http(s) url, got=%s", callbackURL)
g.Log().Warningf(ctx, "[回调] JSON序列化失败 taskId=%s 错误=%v", t.TaskID, err)
return
}
g.Log().Infof(ctx, "[回调] 开始发送 taskId=%s 回调地址=%s 请求头数量=%d 消息体大小=%d字节",
t.TaskID, callbackURL, len(headers), len(jsonData))
// path 末尾拼接 bizName
bizSeg := url.PathEscape(bizName)
if strings.HasSuffix(u.Path, "/") || u.Path == "" {
u.Path = strings.TrimRight(u.Path, "/") + "/" + bizSeg
} else {
u.Path = u.Path + "/" + bizSeg
}
q := u.Query()
q.Set("task_id", t.TaskID)
q.Set("state", fmt.Sprintf("%d", t.State))
q.Set("oss_file", t.OssFile)
q.Set("file_type", t.FileType)
u.RawQuery = q.Encode()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
err = http.Post(ctx, callbackURL, headers, &req, jsonData)
if err != nil {
g.Log().Warningf(ctx, "[callback] build request failed url=%s err=%v", u.String(), err)
g.Log().Warningf(ctx, "[回调] 发送失败 taskId=%s 回调地址=%s 错误=%v", t.TaskID, callbackURL, err)
return
}
// 透传必要头部(如 Authorization / X-User-Info
for k, v := range forwardHeaders(ctx) {
if v != "" {
req.Header.Set(k, v)
}
}
client := &http.Client{Timeout: 5 * time.Second}
resp, err := client.Do(req)
if err != nil {
g.Log().Warningf(ctx, "[callback] request failed url=%s err=%v", u.String(), err)
return
}
defer resp.Body.Close()
b, _ := io.ReadAll(resp.Body)
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
msg := string(b)
if len(msg) > 2000 {
msg = msg[:2000]
}
g.Log().Warningf(ctx, "[callback] non-2xx url=%s code=%d body=%s", u.String(), resp.StatusCode, msg)
return
}
g.Log().Infof(ctx, "[callback] success url=%s code=%d", u.String(), resp.StatusCode)
g.Log().Infof(ctx, "[回调] 发送成功 taskId=%s 回调地址=%s 消息体大小=%d字节", t.TaskID, callbackURL, len(jsonData))
}
// triggerPromptsCallback posts a task's text result to the prompts-core
// session callback endpoint.
//
// The JSON body carries two fields:
//   - "epicycleId": the round ID supplied by the caller
//   - "text":       t.TextResult
//
// Errors are logged and swallowed; nothing is reported back to the caller.
// A nil task is ignored. All log lines print the epicycleId argument, i.e. the
// same value that is sent in the payload.
func triggerPromptsCallback(ctx context.Context, t *entity.AsynchTask, epicycleId int64) {
	if t == nil {
		return
	}

	callbackURL := "prompts-core/session/sessionCallback"
	headers := forwardHeaders(ctx)
	// The response body is decoded into an empty struct, i.e. it is not used.
	var req struct{}

	payload := map[string]interface{}{
		"epicycleId": epicycleId,
		"text":       t.TextResult,
	}
	jsonData, err := json.Marshal(payload)
	if err != nil {
		g.Log().Warningf(ctx, "[提示词回调] JSON序列化失败 epicycleId=%d 错误=%v", epicycleId, err)
		return
	}

	g.Log().Infof(ctx, "[提示词回调] 开始发送 epicycleId=%d 回调地址=%s 请求头数量=%d 消息体大小=%d字节",
		epicycleId, callbackURL, len(headers), len(jsonData))

	if err = http.Post(ctx, callbackURL, headers, &req, jsonData); err != nil {
		g.Log().Warningf(ctx, "[提示词回调] 发送失败 epicycleId=%d 回调地址=%s 错误=%v", epicycleId, callbackURL, err)
		return
	}
	g.Log().Infof(ctx, "[提示词回调] 发送成功 epicycleId=%d 回调地址=%s 消息体大小=%d字节", epicycleId, callbackURL, len(jsonData))
}
// IsSuperAdmin 调用admin-go服务检查是否是超级管理员
func IsSuperAdmin(ctx context.Context) (res bool, err error) {
headers := forwardHeaders(ctx)
var r = make(map[string]bool)
if err = http.Get(ctx, "admin-go/api/v1/system/user/checkIsSuperAdmin", headers, &r); err != nil {
return false, err
}
return r["isSuperAdmin"], err
}
// IsAdmin 调用admin-go服务检查是否是管理员
func IsAdmin(ctx context.Context) (res bool, err error) {
headers := forwardHeaders(ctx)
var r = make(map[string]bool)
if err = http.Get(ctx, "admin-go/api/v1/system/user/checkIsSuperAdmin", headers, &r); err != nil {
return false, err
}
return r["isSuperAdmin"], err
}

View File

@@ -11,6 +11,10 @@ func DetectFileType(data []byte) (contentType string, ext string) {
return "application/octet-stream", ""
}
ct := http.DetectContentType(data)
// http.DetectContentType 可能带 charset 等参数text/plain; charset=utf-8
if idx := strings.Index(ct, ";"); idx > 0 {
ct = strings.TrimSpace(ct[:idx])
}
switch ct {
case "audio/mpeg":
return ct, ".mp3"
@@ -24,12 +28,20 @@ func DetectFileType(data []byte) (contentType string, ext string) {
return ct, ".jpg"
case "application/pdf":
return ct, ".pdf"
case "text/plain":
return ct, ".txt"
case "application/json":
return ct, ".json"
default:
// 兜底:尝试从 ct 截取 subtype 作为后缀(例如 application/json
if parts := strings.Split(ct, "/"); len(parts) == 2 {
return ct, "." + parts[1]
sub := parts[1]
// 避免出现 "plain; charset=utf-8" 之类的后缀
if idx := strings.Index(sub, ";"); idx > 0 {
sub = strings.TrimSpace(sub[:idx])
}
return ct, "." + sub
}
return ct, ""
}
}

View File

@@ -51,4 +51,3 @@ func forwardHeaders(ctx context.Context) map[string]string {
}
return headers
}

View File

@@ -12,6 +12,11 @@ import (
"time"
"model-asynch/model/entity"
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/frame/g"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// parseHeadMsgHeaders 支持多个 header 绑定,逗号分隔:
@@ -100,11 +105,14 @@ func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelK
if m == nil || m.BaseURL == "" {
return nil, fmt.Errorf("模型配置不完整")
}
url := strings.TrimRight(m.BaseURL, "/") + "/" + strings.TrimLeft(m.Route, "/")
if strings.TrimSpace(m.Route) == "" {
url = strings.TrimRight(m.BaseURL, "/")
// ============ 新增:请求参数映射 ============
mappedPayload, err := mapRequestPayload(m.RequestMapping, payload)
if err != nil {
return nil, fmt.Errorf("请求参数映射失败: %w", err)
}
url := strings.TrimRight(m.BaseURL, "/")
timeout := time.Duration(m.TimeoutSeconds) * time.Second
if timeout <= 0 {
timeout = 60 * time.Second
@@ -118,11 +126,10 @@ func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelK
var (
req *http.Request
err error
)
switch method {
case http.MethodGet:
q, err := payloadToQuery(payload)
q, err := payloadToQuery(mappedPayload) // 使用映射后的payload
if err != nil {
return nil, err
}
@@ -135,7 +142,7 @@ func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelK
}
req, err = http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
default:
bodyBytes, err := json.Marshal(payload)
bodyBytes, err := json.Marshal(mappedPayload) // 使用映射后的payload
if err != nil {
return nil, err
}
@@ -145,20 +152,16 @@ func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelK
return nil, err
}
// 先注入模型配置 head_msg静态头部
// 先注入模型配置 head_msg静态头部,适合公共模型固定 API Key
for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) {
req.Header.Set(hk, hv)
}
// 透传必要头部(如 Authorization / X-User-Info
for k, v := range forwardHeaders(ctx) {
if v != "" {
req.Header.Set(k, v)
}
}
// 最后注入动态 modelKey覆盖/补充静态 head_msg
// 最后注入动态 modelKey允许覆盖/补充静态 head_msg适合按请求动态传密钥。
for hk, hv := range parseHeadMsgHeaders(modelKey) {
req.Header.Set(hk, hv)
}
if method != http.MethodGet {
req.Header.Set("Content-Type", "application/json")
}
@@ -174,12 +177,241 @@ func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelK
return nil, err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
// 尽量把错误体带回去,方便排查
msg := string(b)
if len(msg) > 2000 {
msg = msg[:2000]
}
return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
}
return b, nil
// ============ 新增:响应参数映射 ============
mappedResponse, err := mapResponsePayload(m.ResponseMapping, b)
if err != nil {
// 响应映射失败不阻塞,返回原始数据
g.Log().Warningf(ctx, "响应参数映射失败: %v返回原始数据", err)
return b, nil
}
// =========================================
return mappedResponse, nil
}
// ============================================
// 映射相关函数
// ============================================
// mapRequestPayload merges the model's configured request parameters into
// payload and returns the request to send to the model.
//
// mappingAny is the raw request-mapping configuration; it is decoded by
// parseRequestMapping. When no mapping is configured, payload is returned
// unchanged. Otherwise payload is converted to a map:
//   - map[string]any is used as-is
//   - []map[string]any is wrapped as {"messages": payload}
//   - anything else goes through a JSON marshal/unmarshal round trip
//
// A configured value is applied only when the key is missing from the payload
// or its current value is empty according to isEmptyValue; non-empty values
// supplied by the caller are kept.
//
// The merge is done on a fresh map, so the caller's map is never modified and
// a nil payload (or nil map) no longer causes a write to a nil map.
func mapRequestPayload(mappingAny any, payload any) (any, error) {
	mapping, err := parseRequestMapping(mappingAny)
	if err != nil {
		return nil, err
	}
	if len(mapping) == 0 {
		return payload, nil
	}

	// Bring the payload into map form. source may legitimately stay nil here
	// (nil payload, nil map, JSON null); it is only read below.
	var source map[string]any
	switch v := payload.(type) {
	case map[string]any:
		source = v
	case []map[string]any:
		source = map[string]any{
			"messages": v,
		}
	default:
		jsonBytes, err := json.Marshal(payload)
		if err != nil {
			return nil, fmt.Errorf("序列化payload失败: %w", err)
		}
		if err := json.Unmarshal(jsonBytes, &source); err != nil {
			return nil, fmt.Errorf("反序列化payload失败: %w", err)
		}
	}

	// Copy first so the defaults are written into our own map.
	merged := make(map[string]any, len(source)+len(mapping))
	for key, value := range source {
		merged[key] = value
	}
	for key, value := range mapping {
		if existing, ok := merged[key]; !ok || isEmptyValue(existing) {
			merged[key] = value
		}
	}
	return merged, nil
}
// mapResponsePayload builds a new JSON object from a model response using the
// configured response mapping (output field -> gjson path into the response).
//
// With no mapping configured the response bytes are returned untouched. Paths
// that do not exist in the response are skipped; matched values are copied as
// raw JSON into the output object.
func mapResponsePayload(mappingAny any, responseBytes []byte) ([]byte, error) {
	fieldPaths, err := parseResponseMapping(mappingAny)
	if err != nil {
		return nil, err
	}
	if len(fieldPaths) == 0 {
		return responseBytes, nil
	}

	source := string(responseBytes)
	output := `{}`
	for target, path := range fieldPaths {
		found := gjson.Get(source, path)
		if found.Exists() {
			updated, setErr := sjson.SetRaw(output, target, found.Raw)
			if setErr != nil {
				return nil, fmt.Errorf("提取字段 %s <- %s 失败: %w", target, path, setErr)
			}
			output = updated
		}
	}
	return []byte(output), nil
}
// parseRequestMapping decodes the request-mapping configuration into a map of
// parameter name to value.
//
// Accepted inputs are *gvar.Var, map[string]interface{}, a JSON string, JSON
// bytes, or any other value that survives a JSON marshal/unmarshal round trip.
// A nil input, an empty Var, an empty byte slice, and the strings "", "{}" and
// "null" all yield (nil, nil).
func parseRequestMapping(mappingAny any) (map[string]any, error) {
	if mappingAny == nil {
		return nil, nil
	}

	parsed := make(map[string]any)
	switch cfg := mappingAny.(type) {
	case *gvar.Var:
		if cfg == nil || cfg.IsNil() || cfg.IsEmpty() {
			return nil, nil
		}
		// Prefer the Var's own map conversion.
		if asMap := cfg.Map(); asMap != nil {
			for name, value := range asMap {
				parsed[name] = value
			}
			return parsed, nil
		}
		// Otherwise fall back to decoding its string form as JSON.
		text := cfg.String()
		if text == "" || text == "{}" || text == "null" {
			return nil, nil
		}
		if err := json.Unmarshal([]byte(text), &parsed); err != nil {
			return nil, fmt.Errorf("解析请求映射字符串失败: %w", err)
		}
		return parsed, nil

	case map[string]interface{}:
		// Already in the target shape; handed back without copying.
		return cfg, nil

	case string:
		switch cfg {
		case "", "{}", "null":
			return nil, nil
		}
		if err := json.Unmarshal([]byte(cfg), &parsed); err != nil {
			return nil, fmt.Errorf("解析请求映射字符串失败: %w", err)
		}

	case []byte:
		if len(cfg) == 0 {
			return nil, nil
		}
		if err := json.Unmarshal(cfg, &parsed); err != nil {
			return nil, fmt.Errorf("解析请求映射字节失败: %w", err)
		}

	default:
		encoded, err := json.Marshal(mappingAny)
		if err != nil {
			return nil, fmt.Errorf("序列化映射配置失败: %w", err)
		}
		if err := json.Unmarshal(encoded, &parsed); err != nil {
			return nil, fmt.Errorf("解析映射配置失败: %w", err)
		}
	}
	return parsed, nil
}
// parseResponseMapping decodes the response-mapping configuration into a map
// of output field name to JSON path string.
//
// Accepted inputs are *gvar.Var, map[string]interface{}, a JSON string, JSON
// bytes, or any other value that survives a JSON marshal/unmarshal round trip.
// For map-shaped inputs, entries whose value is not a string are dropped.
// A nil input, an empty Var, an empty byte slice, and the strings "", "{}" and
// "null" all yield (nil, nil).
func parseResponseMapping(mappingAny any) (map[string]string, error) {
	if mappingAny == nil {
		return nil, nil
	}

	paths := make(map[string]string)
	switch cfg := mappingAny.(type) {
	case *gvar.Var:
		if cfg == nil || cfg.IsNil() || cfg.IsEmpty() {
			return nil, nil
		}
		if asMap := cfg.Map(); asMap != nil {
			for field, raw := range asMap {
				if path, isString := raw.(string); isString {
					paths[field] = path
				}
			}
			return paths, nil
		}
		text := cfg.String()
		if text == "" || text == "{}" || text == "null" {
			return nil, nil
		}
		if err := json.Unmarshal([]byte(text), &paths); err != nil {
			return nil, fmt.Errorf("解析响应映射字符串失败: %w", err)
		}
		return paths, nil

	case string:
		switch cfg {
		case "", "{}", "null":
			return nil, nil
		}
		if err := json.Unmarshal([]byte(cfg), &paths); err != nil {
			return nil, fmt.Errorf("解析响应映射字符串失败: %w", err)
		}

	case map[string]interface{}:
		// Map handed over directly (per the original note: a database JSONB value).
		for field, raw := range cfg {
			if path, isString := raw.(string); isString {
				paths[field] = path
			}
		}

	case []byte:
		if len(cfg) == 0 {
			return nil, nil
		}
		if err := json.Unmarshal(cfg, &paths); err != nil {
			return nil, fmt.Errorf("解析响应映射字节失败: %w", err)
		}

	default:
		encoded, err := json.Marshal(mappingAny)
		if err != nil {
			return nil, fmt.Errorf("序列化响应映射配置失败: %w", err)
		}
		if err := json.Unmarshal(encoded, &paths); err != nil {
			return nil, fmt.Errorf("解析响应映射配置失败: %w", err)
		}
	}
	return paths, nil
}
// isEmptyValue reports whether v counts as "not provided": nil, the empty
// string, a zero-length []any, or a zero-length map[string]any. Every other
// value (including 0, false and slices/maps of other types) is non-empty.
func isEmptyValue(v any) bool {
	if v == nil {
		return true
	}
	if text, ok := v.(string); ok {
		return text == ""
	}
	if list, ok := v.([]any); ok {
		return len(list) == 0
	}
	if object, ok := v.(map[string]any); ok {
		return len(object) == 0
	}
	return false
}

View File

@@ -3,10 +3,15 @@ package service
import (
"context"
"errors"
"sort"
"model-asynch/dao"
"model-asynch/model/dto"
"model-asynch/model/entity"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
var Model = &modelService{}
@@ -15,40 +20,28 @@ type modelService struct{}
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
m := &entity.AsynchModel{
ModelName: req.ModelName,
ModelsType: normalizeModelsType(req.ModelsType),
BaseURL: req.BaseURL,
Route: req.Route,
HttpMethod: req.HttpMethod,
HeadMsg: req.HeadMsg,
Form: req.Form,
Enabled: req.Enabled,
MaxConcurrency: req.MaxConcurrency,
QueueLimit: req.QueueLimit,
TimeoutSeconds: req.TimeoutSeconds,
ExpectedSeconds: req.ExpectedSeconds,
RetryTimes: req.RetryTimes,
RetryQueueMaxSecs: req.RetryQueueMaxSeconds,
AutoCleanSeconds: req.AutoCleanSeconds,
Remark: req.Remark,
}
if m.HttpMethod == "" {
m.HttpMethod = "POST"
}
if m.Enabled == 0 {
m.Enabled = 1
}
if m.MaxConcurrency <= 0 {
m.MaxConcurrency = 10
}
if m.QueueLimit <= 0 {
m.QueueLimit = 1000
}
if m.TimeoutSeconds <= 0 {
m.TimeoutSeconds = 60
}
if m.AutoCleanSeconds <= 0 {
m.AutoCleanSeconds = 86400
ModelName: req.ModelName,
ModelsType: req.ModelsType,
BaseURL: req.BaseURL,
HttpMethod: req.HttpMethod,
HeadMsg: req.HeadMsg,
IsPrivate: req.IsPrivate,
Enabled: req.Enabled,
IsChatModel: req.IsChatModel,
ApiKey: req.ApiKey,
Form: req.Form,
RequestMapping: req.RequestMapping,
ResponseMapping: req.ResponseMapping,
ResponseBody: req.ResponseBody,
TokenMapping: req.TokenMapping,
MaxConcurrency: req.MaxConcurrency,
QueueLimit: req.QueueLimit,
TimeoutSeconds: req.TimeoutSeconds,
ExpectedSeconds: req.ExpectedSeconds,
RetryTimes: req.RetryTimes,
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
AutoCleanSeconds: req.AutoCleanSeconds,
Remark: req.Remark,
}
id, err := dao.Model.Insert(ctx, m)
if err != nil {
@@ -58,68 +51,223 @@ func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res
}
func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) error {
data := map[string]any{}
if req.BaseURL != "" {
data[entity.AsynchModelCol.BaseURL] = req.BaseURL
//根据当前 isChatModel 来判断是否更新模型
if req.IsChatModel == 1 {
user, err := utils.GetUserInfo(ctx)
if err != nil {
return err
}
//判断当前用户是否有会话模型
model, err := dao.Model.GetByIsChatModel(ctx, user.UserName)
if err != nil {
return err
}
if model != nil {
return errors.New("用户已存在会话模型,不能创建新的会话模型")
}
_, err = dao.Model.Update(ctx, req)
return err
}
if req.Route != "" {
data[entity.AsynchModelCol.Route] = req.Route
}
if req.HttpMethod != nil && *req.HttpMethod != "" {
data[entity.AsynchModelCol.HttpMethod] = *req.HttpMethod
}
if req.HeadMsg != nil {
data[entity.AsynchModelCol.HeadMsg] = *req.HeadMsg
}
if req.Form != nil {
data[entity.AsynchModelCol.FormJSON] = req.Form
}
if req.ModelsType != nil {
data[entity.AsynchModelCol.ModelsType] = normalizeModelsType(*req.ModelsType)
}
if req.Enabled != nil {
data[entity.AsynchModelCol.Enabled] = *req.Enabled
}
if req.MaxConcurrency != nil {
data[entity.AsynchModelCol.MaxConcurrency] = *req.MaxConcurrency
}
if req.QueueLimit != nil {
data[entity.AsynchModelCol.QueueLimit] = *req.QueueLimit
}
if req.TimeoutSeconds != nil {
data[entity.AsynchModelCol.TimeoutSeconds] = *req.TimeoutSeconds
}
if req.ExpectedSeconds != nil {
data[entity.AsynchModelCol.ExpectedSeconds] = *req.ExpectedSeconds
}
if req.RetryTimes != nil {
data[entity.AsynchModelCol.RetryTimes] = *req.RetryTimes
}
if req.RetryQueueMaxSeconds != nil {
data[entity.AsynchModelCol.RetryQueueMaxSecs] = *req.RetryQueueMaxSeconds
}
if req.AutoCleanSeconds != nil {
data[entity.AsynchModelCol.AutoCleanSeconds] = *req.AutoCleanSeconds
}
if req.Remark != nil {
data[entity.AsynchModelCol.Remark] = *req.Remark
}
if len(data) == 0 {
return errors.New("无可更新字段")
}
_, err := dao.Model.UpdateByID(ctx, req.ID, data)
_, err := dao.Model.Update(ctx, req)
return err
}
func (s *modelService) Delete(ctx context.Context, id int64) error {
func (s *modelService) Delete(ctx context.Context, id string) error {
_, err := dao.Model.DeleteByID(ctx, id)
return err
}
func (s *modelService) Get(ctx context.Context, id int64) (*entity.AsynchModel, error) {
return dao.Model.GetByID(ctx, id)
model, err := dao.Model.Get(ctx, id)
if err != nil {
return nil, err
}
model.Form = ParseJSONField(model.Form)
model.RequestMapping = ParseJSONField(model.RequestMapping)
model.ResponseMapping = ParseJSONField(model.ResponseMapping)
model.ResponseBody = ParseJSONField(model.ResponseBody)
return model, nil
}
func (s *modelService) List(ctx context.Context, pageNum, pageSize int, modelNameLike string) (list []*entity.AsynchModel, total int64, err error) {
return dao.Model.List(ctx, pageNum, pageSize, modelNameLike)
// List returns one page of models plus the total count.
//
// Super admins (per IsSuperAdmin) get the plain DAO listing; everyone else
// gets the per-user, name-deduplicated listing from getModelsWithDedup. The
// Form, RequestMapping, ResponseMapping and ResponseBody fields of every
// returned model are passed through ParseJSONField before returning.
func (s *modelService) List(ctx context.Context, pageNum, pageSize int, modelNameLike string, modelType int) (list []*entity.AsynchModel, total int64, err error) {
	superAdmin, err := IsSuperAdmin(ctx)
	if err != nil {
		return nil, 0, err
	}
	// The user is resolved up front, even on the super-admin path.
	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return nil, 0, err
	}

	if superAdmin {
		list, total, err = dao.Model.List(ctx, pageNum, pageSize, modelNameLike, modelType)
	} else {
		list, total, err = s.getModelsWithDedup(ctx, user.UserName, pageNum, pageSize, modelNameLike, modelType)
	}
	if err != nil {
		return nil, 0, err
	}

	// Normalise the JSON-typed fields of every row.
	for _, item := range list {
		item.Form = ParseJSONField(item.Form)
		item.RequestMapping = ParseJSONField(item.RequestMapping)
		item.ResponseMapping = ParseJSONField(item.ResponseMapping)
		item.ResponseBody = ParseJSONField(item.ResponseBody)
	}
	return list, total, nil
}
// getModelsWithDedup returns the model list for a non-super-admin user,
// deduplicated by model name and paginated in memory.
//
// Steps:
//  1. Load every candidate model via dao.Model.GetByCreatorAndPlatform (no
//     paging, so deduplication sees the full set).
//  2. Keep one model per ModelName. On a name clash the entry whose Creator
//     equals creator wins; otherwise the first one seen is kept.
//  3. Sort by CreatedAt, newest first.
//  4. Slice out the requested page when pageNum and pageSize are both > 0;
//     otherwise return the whole list. total is the deduplicated count.
//
// Deduplication keeps first-seen order and the sort is stable, so models with
// identical CreatedAt values keep the DAO's order on every call. Building the
// slice from map iteration would shuffle such ties between calls and make
// pages overlap or skip entries.
func (s *modelService) getModelsWithDedup(ctx context.Context, creator string, pageNum, pageSize int, modelNameLike string, modelType int) (list []*entity.AsynchModel, total int64, err error) {
	allModels, err := dao.Model.GetByCreatorAndPlatform(ctx, creator, modelNameLike, modelType)
	if err != nil {
		return nil, 0, err
	}

	// position maps a model name to its slot in deduped.
	position := make(map[string]int, len(allModels))
	deduped := make([]*entity.AsynchModel, 0, len(allModels))
	for _, m := range allModels {
		if m == nil {
			continue
		}
		slot, seen := position[m.ModelName]
		if !seen {
			position[m.ModelName] = len(deduped)
			deduped = append(deduped, m)
			continue
		}
		// Name clash: the current user's own model takes the slot.
		if m.Creator == creator {
			deduped[slot] = m
		}
	}

	sort.SliceStable(deduped, func(i, j int) bool {
		return deduped[i].CreatedAt.After(deduped[j].CreatedAt)
	})

	total = int64(len(deduped))
	if pageNum > 0 && pageSize > 0 {
		start := (pageNum - 1) * pageSize
		if start >= len(deduped) {
			return []*entity.AsynchModel{}, total, nil
		}
		end := start + pageSize
		if end > len(deduped) {
			end = len(deduped)
		}
		deduped = deduped[start:end]
	}
	return deduped, total, nil
}
// GetModelTypesFromConfig reads the model-type table (type ID -> display name)
// from the "modelType.types" configuration key.
//
// Entries whose key does not convert to a positive integer or whose name is
// empty are ignored. If nothing usable is configured, a built-in default table
// with five types is returned instead.
func GetModelTypesFromConfig(ctx context.Context) map[int]string {
	configured := g.Cfg().MustGet(ctx, "modelType.types").Map()

	types := make(map[int]string)
	for rawID, rawName := range configured {
		id, name := gconv.Int(rawID), gconv.String(rawName)
		if id <= 0 || name == "" {
			continue
		}
		types[id] = name
	}
	if len(types) > 0 {
		return types
	}

	// Nothing configured: fall back to the defaults.
	return map[int]string{
		1: "推理模型",
		2: "图片模型",
		3: "音频模型",
		4: "向量化模型",
		5: "全模态模型",
	}
}
// UpdateChatModel toggles which model is the current user's chat model.
//
// Behaviour:
//   - The model identified by req.Id must exist and have ModelsType == 1,
//     otherwise an error is returned.
//   - If that model already is the user's chat model, its flag is cleared
//     (toggle off) and nothing else happens.
//   - Otherwise any existing chat model of the user is cleared first and the
//     requested model is then flagged as the chat model.
//
// The type check is applied to the model being set. It used to read the
// user's existing chat model before checking it for nil, which panicked for
// users who had no chat model yet and validated the wrong record.
//
// NOTE(review): the clear and set updates are separate DAO calls with no
// transaction visible here; a failure in between leaves the user without a
// chat model.
func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error {
	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return err
	}

	// Validate the model that is about to become the chat model.
	target, err := dao.Model.Get(ctx, req.Id)
	if err != nil {
		return err
	}
	if target == nil {
		return errors.New("新会话模型不存在")
	}
	if target.ModelsType != 1 {
		return errors.New("当前模型为非推理模型,不能设置为会话模型")
	}

	// Look up the user's existing chat model; it may legitimately be nil.
	current, err := dao.Model.GetByIsChatModel(ctx, user.UserName)
	if err != nil {
		return err
	}
	if current != nil {
		// Clear the existing chat model in both the toggle-off and switch cases.
		if _, err = dao.Model.UpdateByID(ctx, &dto.UpdateModelReq{
			ID:          current.Id,
			IsChatModel: 0,
		}); err != nil {
			return err
		}
		// The requested model was the active one: this was a toggle-off.
		if current.Id == req.Id {
			return nil
		}
	}

	// Flag the requested model as the chat model.
	_, err = dao.Model.UpdateByID(ctx, &dto.UpdateModelReq{
		ID:          req.Id,
		IsChatModel: 1,
	})
	return err
}
// GetIsChatModel returns the current user's chat model, or (nil, nil) when the
// user has none. The Form, RequestMapping, ResponseMapping and ResponseBody
// fields are passed through ParseJSONField before the model is returned.
func (s *modelService) GetIsChatModel(ctx context.Context) (*entity.AsynchModel, error) {
	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return nil, err
	}

	chatModel, err := dao.Model.GetByIsChatModel(ctx, user.UserName)
	switch {
	case err != nil:
		return nil, err
	case chatModel == nil:
		return nil, nil
	}

	chatModel.Form = ParseJSONField(chatModel.Form)
	chatModel.RequestMapping = ParseJSONField(chatModel.RequestMapping)
	chatModel.ResponseMapping = ParseJSONField(chatModel.ResponseMapping)
	chatModel.ResponseBody = ParseJSONField(chatModel.ResponseBody)
	return chatModel, nil
}

View File

@@ -1,217 +0,0 @@
package service
import (
"context"
"encoding/json"
"errors"
"strings"
"model-asynch/dao"
"model-asynch/model/dto"
"model-asynch/model/entity"
"github.com/gogf/gf/v2/container/gvar"
)
type modelTypeService struct{}
var ModelType = &modelTypeService{}
// normalizeFormValue converts a stored form value into decoded JSON so callers
// receive an array/object rather than a JSON string.
//
// nil, blank strings, empty byte slices and nil/empty *gvar.Var values become
// an empty []any. Strings and bytes are decoded by the
// normalizeFormValueFromJSON* helpers. Any other value that exposes Bytes() or
// String() is decoded from that representation; everything else is returned
// unchanged.
func normalizeFormValue(v any) any {
	if v == nil {
		return []any{}
	}

	switch val := v.(type) {
	case string:
		if trimmed := strings.TrimSpace(val); trimmed != "" {
			return normalizeFormValueFromJSONString(trimmed)
		}
		return []any{}
	case []byte:
		if len(val) > 0 {
			return normalizeFormValueFromJSONBytes(val)
		}
		return []any{}
	case *gvar.Var:
		// Per the original note: the type GoFrame commonly returns from the DB.
		if val == nil {
			return []any{}
		}
		if raw := val.Bytes(); len(raw) > 0 {
			return normalizeFormValueFromJSONBytes(raw)
		}
		if trimmed := strings.TrimSpace(val.String()); trimmed != "" {
			return normalizeFormValueFromJSONString(trimmed)
		}
		return []any{}
	}

	// Other JSON-like wrappers: try their byte form first, then their string form.
	if withBytes, ok := v.(interface{ Bytes() []byte }); ok {
		if raw := withBytes.Bytes(); len(raw) > 0 {
			return normalizeFormValueFromJSONBytes(raw)
		}
	}
	if withString, ok := v.(interface{ String() string }); ok {
		if trimmed := strings.TrimSpace(withString.String()); trimmed != "" {
			return normalizeFormValueFromJSONString(trimmed)
		}
	}
	// Already a decoded structure such as []any or map[string]any.
	return v
}
// 兼容“JSONB 里存了 JSON 字符串”的历史数据:
// 例如 form_json = '"[]"' 或 '"[{...}]"'(外层是字符串,内层才是数组/对象)
func normalizeFormValueFromJSONString(s string) any {
var out any
if err := json.Unmarshal([]byte(s), &out); err != nil || out == nil {
return []any{}
}
// 如果解出来还是 string且看起来是 JSON再解一层
if inner, ok := out.(string); ok {
inner = strings.TrimSpace(inner)
if inner == "" {
return []any{}
}
if strings.HasPrefix(inner, "[") || strings.HasPrefix(inner, "{") {
var out2 any
if err := json.Unmarshal([]byte(inner), &out2); err == nil && out2 != nil {
return out2
}
}
return []any{}
}
return out
}
// normalizeFormValueFromJSONBytes decodes b as JSON. Invalid JSON and JSON
// null yield an empty []any. If the decoded value is a string, it is handed to
// normalizeFormValueFromJSONString for another decoding pass; any other
// decoded value is returned as-is.
func normalizeFormValueFromJSONBytes(b []byte) any {
	var decoded any
	if json.Unmarshal(b, &decoded) != nil || decoded == nil {
		return []any{}
	}
	text, isString := decoded.(string)
	if !isString {
		return decoded
	}
	return normalizeFormValueFromJSONString(text)
}
// Create inserts a new model type built from req (TypeID, TypeName, Remark)
// and returns the ID reported by the DAO.
func (s *modelTypeService) Create(ctx context.Context, req *dto.CreateModelTypeReq) (res *dto.CreateModelTypeRes, err error) {
	id, err := dao.ModelType.Insert(ctx, &entity.AsynchModelType{
		TypeID:   req.TypeID,
		TypeName: req.TypeName,
		Remark:   req.Remark,
	})
	if err != nil {
		return nil, err
	}
	res = &dto.CreateModelTypeRes{ID: id}
	return res, nil
}
// Update applies a partial update to the model type req.ID. Only the fields
// whose pointer in req is non-nil (TypeID, TypeName, Remark) are written; if
// none is set an error is returned and the DAO is not called.
func (s *modelTypeService) Update(ctx context.Context, req *dto.UpdateModelTypeReq) error {
	fields := make(map[string]any, 3)
	if v := req.TypeID; v != nil {
		fields[entity.AsynchModelTypeCol.TypeID] = *v
	}
	if v := req.TypeName; v != nil {
		fields[entity.AsynchModelTypeCol.TypeName] = *v
	}
	if v := req.Remark; v != nil {
		fields[entity.AsynchModelTypeCol.Remark] = *v
	}
	if len(fields) == 0 {
		return errors.New("无可更新字段")
	}

	_, err := dao.ModelType.UpdateByID(ctx, req.ID, fields)
	return err
}
// Delete removes the model type with the given id via the DAO and returns the
// DAO's error, if any.
func (s *modelTypeService) Delete(ctx context.Context, id int64) error {
	if _, err := dao.ModelType.DeleteByID(ctx, id); err != nil {
		return err
	}
	return nil
}
// Get fetches the model type with the given id from the DAO and returns the
// DAO's result unchanged.
func (s *modelTypeService) Get(ctx context.Context, id int64) (*entity.AsynchModelType, error) {
	modelType, err := dao.ModelType.GetByID(ctx, id)
	return modelType, err
}
// List returns one page of model types and the total count, delegating paging
// and the typeNameLike filter to the DAO.
func (s *modelTypeService) List(ctx context.Context, pageNum, pageSize int, typeNameLike string) (list []*entity.AsynchModelType, total int64, err error) {
	list, total, err = dao.ModelType.List(ctx, pageNum, pageSize, typeNameLike)
	return list, total, err
}
// ListWithModels returns the models grouped by model type, as a slice so the
// caller can render it directly.
//
// Up to 1000 types are loaded and optionally filtered by req.TypeID (exact
// match when > 0) and req.Type (substring match on the type name, after
// trimming). Every model from dao.Model.ListAll is then attached to each type
// listed in its ModelsType field (e.g. "1,2,3"). Types that pass the filter
// but have no models are returned with an empty, non-nil Items slice.
func (s *modelTypeService) ListWithModels(ctx context.Context, req *dto.ListModelTypeWithModelsReq) (res []dto.ModelTypeWithModelsItem, err error) {
	types, _, err := dao.ModelType.List(ctx, 1, 1000, "")
	if err != nil {
		return nil, err
	}

	// Resolve the optional filters.
	var (
		wantTypeID   int
		wantTypeName string
	)
	if req != nil {
		wantTypeID = req.TypeID
		wantTypeName = strings.TrimSpace(req.Type)
	}

	// Collect the types that pass the filters, in DAO order.
	selectedIDs := make([]int, 0, len(types))
	namesByID := make(map[int]string, len(types))
	for _, typ := range types {
		switch {
		case typ == nil:
			continue
		case wantTypeID > 0 && typ.TypeID != wantTypeID:
			continue
		case wantTypeName != "" && !strings.Contains(typ.TypeName, wantTypeName):
			continue
		}
		selectedIDs = append(selectedIDs, typ.TypeID)
		namesByID[typ.TypeID] = typ.TypeName
	}

	models, err := dao.Model.ListAll(ctx)
	if err != nil {
		return nil, err
	}

	// Group models under every type ID they declare.
	grouped := map[int][]dto.ModelTypeModelItem{}
	for _, m := range models {
		if m == nil {
			continue
		}
		form := normalizeFormValue(m.Form)
		for _, typeID := range parseModelsTypeIDs(m.ModelsType) {
			// With a type-ID filter, only that type is emitted.
			if wantTypeID > 0 && typeID != wantTypeID {
				continue
			}
			// With a name filter, only types that matched it are emitted.
			if wantTypeName != "" {
				if _, matched := namesByID[typeID]; !matched {
					continue
				}
			}
			grouped[typeID] = append(grouped[typeID], dto.ModelTypeModelItem{
				ID:   m.Id,
				Name: m.ModelName,
				Form: form,
			})
		}
	}

	result := make([]dto.ModelTypeWithModelsItem, 0, len(selectedIDs))
	for _, typeID := range selectedIDs {
		items := grouped[typeID]
		if items == nil {
			items = make([]dto.ModelTypeModelItem, 0)
		}
		result = append(result, dto.ModelTypeWithModelsItem{
			TypeID: typeID,
			Type:   namesByID[typeID],
			Items:  items,
		})
	}
	return result, nil
}

View File

@@ -1,52 +0,0 @@
package service
import (
"sort"
"strconv"
"strings"
)
// normalizeModelsType canonicalises a models_type string such as "1, 2,2,3"
// into "1,2,3": whitespace removed, duplicates dropped, IDs in ascending
// order. It returns "" when no valid ID is present.
func normalizeModelsType(v string) string {
	var joined strings.Builder
	for i, id := range parseModelsTypeIDs(v) {
		if i > 0 {
			joined.WriteByte(',')
		}
		joined.WriteString(strconv.Itoa(id))
	}
	return joined.String()
}
// parseModelsTypeIDs parses a comma-separated models_type value (e.g. "1,2,3")
// into a deduplicated, ascending list of positive integer IDs. Blank entries,
// zero, negative numbers and non-numeric entries are skipped. A blank input
// returns nil.
func parseModelsTypeIDs(v string) []int {
	trimmed := strings.TrimSpace(v)
	if trimmed == "" {
		return nil
	}

	fields := strings.Split(trimmed, ",")
	ids := make([]int, 0, len(fields))
	known := make(map[int]struct{}, len(fields))
	for _, field := range fields {
		// Blank fields fail Atoi and "0" fails the positivity check.
		id, err := strconv.Atoi(strings.TrimSpace(field))
		if err != nil || id <= 0 {
			continue
		}
		if _, dup := known[id]; dup {
			continue
		}
		known[id] = struct{}{}
		ids = append(ids, id)
	}
	sort.Ints(ids)
	return ids
}

View File

@@ -13,12 +13,12 @@ var Stat = &statService{}
func (s *statService) List(ctx context.Context, req *dto.ListModelStatReq) (res *dto.ListModelStatRes, err error) {
pageNum, pageSize := 1, 10
if req != nil && req.Page != nil {
if req.Page.PageNum > 0 {
pageNum = int(req.Page.PageNum)
if req != nil {
if req.PageNum > 0 {
pageNum = req.PageNum
}
if req.Page.PageSize > 0 {
pageSize = int(req.Page.PageSize)
if req.PageSize > 0 {
pageSize = req.PageSize
}
}
startDay, endDay := "", ""
@@ -37,4 +37,3 @@ func (s *statService) List(ctx context.Context, req *dto.ListModelStatReq) (res
}
return &dto.ListModelStatRes{List: list, Total: total}, nil
}

View File

@@ -62,7 +62,6 @@ func (s *ossStorage) UploadByTask(ctx context.Context, _ *entity.AsynchTask, dat
if err := commonHttp.Post(ctx, fullURL, headers, &resp, body.Bytes()); err != nil {
return "", err
}
fmt.Println("打印结果 resp:", resp)
g.Log().Infof(ctx, "[OSS] upload success url=%s size=%d format=%s", resp.FileURL, resp.FileSize, resp.FileFormat)
return resp.FileURL, nil
}

View File

@@ -58,9 +58,10 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
State: 0,
BizName: req.BizName,
CallbackURL: req.CallbackUrl,
ModelKey: req.ModelKey,
ModelKey: m.ApiKey,
InputRef: req.InputRef,
RequestPayload: storedPayload,
EpicycleId: req.EpicycleId,
}
_, err = dao.Task.Insert(ctx, t)
if err != nil {
@@ -80,7 +81,7 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
apiPath = r.URL.Path
httpMethod = r.Method
}
_, _ = dao.OpLog.Insert(ctx, &entity.AsynchOpLog{
_, _ = dao.OpLog.Insert(ctx, &entity.LogsModelOp{
IP: ip,
UserAgent: ua,
APIPath: apiPath,
@@ -97,9 +98,80 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
"taskId": taskID,
},
})
// 4) 创建成功后立即异步尝试执行当前任务,并仅在任务仍处于 pending(state=0) 时做定向轮询。
// 一旦任务进入 running/success/failed/downloaded就停止轮询避免一直空转。
go s.pollAndRunUntilPicked(context.WithoutCancel(ctx), taskID, req.EpicycleId)
return &dto.CreateTaskRes{TaskID: taskID}, nil
}
// pollAndRunUntilPicked is the lightweight, task-specific polling loop started
// right after a task is created. Its goal is to get that one task running as
// soon as possible.
//
// It tries once immediately and then once per tick. The tick interval comes
// from the "asynch.worker.intervalSeconds" config key and defaults to 5s when
// the value is <= 0. Each attempt reloads the task:
//   - state 0 (pending): calls AsyncWorker.RunByTaskID and keeps polling,
//     whether or not that call returned an error
//   - state 1 (running), 2/3/4 (terminal) or any unknown state: stop
//   - lookup error or task not found: stop
//
// The loop also stops when ctx is done.
//
// NOTE(review): a task that stays in state 0 is retried for as long as ctx
// lives; there is no attempt limit in this function.
func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string, epicycleId int64) {
	if taskID == "" {
		return
	}

	seconds := g.Cfg().MustGet(ctx, "asynch.worker.intervalSeconds").Int()
	if seconds <= 0 {
		seconds = 5
	}
	g.Log().Infof(ctx, "[task-auto-run][start] taskId=%s interval=%ds", taskID, seconds)

	ticker := time.NewTicker(time.Duration(seconds) * time.Second)
	defer ticker.Stop()

	// attempt performs one poll and reports whether polling should stop.
	attempt := func() (done bool) {
		task, err := dao.Task.GetByTaskID(ctx, taskID)
		if err != nil {
			g.Log().Warningf(ctx, "[task-auto-run][stop] taskId=%s reason=query_failed err=%v", taskID, err)
			return true
		}
		if task == nil {
			g.Log().Warningf(ctx, "[task-auto-run][stop] taskId=%s reason=task_not_found", taskID)
			return true
		}

		// Still pending: try to run it and come back on the next tick.
		if task.State == 0 {
			runErr := AsyncWorker.RunByTaskID(ctx, taskID, epicycleId)
			if runErr == nil {
				g.Log().Infof(ctx, "[task-auto-run][triggered] taskId=%s state=0", taskID)
			} else {
				g.Log().Warningf(ctx, "[task-auto-run][retry] taskId=%s state=0 err=%v", taskID, runErr)
			}
			return false
		}

		// Any other state ends the polling; only the log line differs.
		switch task.State {
		case 1:
			g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=running", taskID)
		case 2, 3, 4:
			g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=terminal state=%d", taskID, task.State)
		default:
			g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=unknown_state state=%d", taskID, task.State)
		}
		return true
	}

	// The first attempt runs immediately; later ones wait for the ticker.
	for !attempt() {
		select {
		case <-ctx.Done():
			g.Log().Infof(ctx, "[task-auto-run][stop] taskId=%s reason=context_done", taskID)
			return
		case <-ticker.C:
		}
	}
}
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
t, err := dao.Task.GetByTaskID(ctx, taskID)
if err != nil {
@@ -168,12 +240,12 @@ func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (r
func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
pageNum, pageSize := 1, 10
if req != nil && req.Page != nil {
if req.Page.PageNum > 0 {
pageNum = int(req.Page.PageNum)
if req != nil {
if req.PageNum > 0 {
pageNum = req.PageNum
}
if req.Page.PageSize > 0 {
pageSize = int(req.Page.PageSize)
if req.PageSize > 0 {
pageSize = req.PageSize
}
}
modelName := ""

113
service/utils.go Normal file
View File

@@ -0,0 +1,113 @@
package service
import (
"encoding/json"
"strings"
"github.com/gogf/gf/v2/container/gvar"
)
// normalizeFormValue converts a stored form value into a decoded JSON value,
// so callers receive an array/object rather than a JSON string.
//
// nil, blank strings, empty byte slices and nil or empty *gvar.Var values
// yield an empty []any. Strings, byte slices and *gvar.Var contents are decoded
// through normalizeFormValueFromJSONString / normalizeFormValueFromJSONBytes.
// Any other value exposing Bytes() or String() is decoded from that output when
// it is non-empty; everything else is returned unchanged.
func normalizeFormValue(v any) any {
	if v == nil {
		return []any{}
	}

	if text, ok := v.(string); ok {
		if text = strings.TrimSpace(text); text == "" {
			return []any{}
		}
		return normalizeFormValueFromJSONString(text)
	}

	if raw, ok := v.([]byte); ok {
		if len(raw) == 0 {
			return []any{}
		}
		return normalizeFormValueFromJSONBytes(raw)
	}

	if gv, ok := v.(*gvar.Var); ok {
		// Type commonly returned by goframe for DB values.
		if gv == nil {
			return []any{}
		}
		// Prefer the byte form; fall back to the trimmed string form.
		if raw := gv.Bytes(); len(raw) > 0 {
			return normalizeFormValueFromJSONBytes(raw)
		}
		if text := strings.TrimSpace(gv.String()); text != "" {
			return normalizeFormValueFromJSONString(text)
		}
		return []any{}
	}

	// Compatibility path for other JSON-like wrappers exposing Bytes()/String().
	if byteser, ok := v.(interface{ Bytes() []byte }); ok {
		if raw := byteser.Bytes(); len(raw) > 0 {
			return normalizeFormValueFromJSONBytes(raw)
		}
	}
	if stringer, ok := v.(interface{ String() string }); ok {
		if text := strings.TrimSpace(stringer.String()); text != "" {
			return normalizeFormValueFromJSONString(text)
		}
	}

	// Already a structured value such as []any / map[string]any.
	return v
}
// normalizeFormValueFromJSONString decodes s as JSON, tolerating legacy data
// where the JSONB column holds a JSON string, e.g. form_json = '"[]"' or
// '"[{...}]"' (the outer layer is a string, the inner one is the array/object).
//
// Invalid JSON and JSON null yield an empty []any. When the decoded value is a
// string, it is trimmed and decoded once more if it starts with "[" or "{";
// any other string, or a failed second decode, yields an empty []any. Non-string
// decoded values are returned as decoded.
func normalizeFormValueFromJSONString(s string) any {
	var decoded any
	if json.Unmarshal([]byte(s), &decoded) != nil || decoded == nil {
		return []any{}
	}

	nested, isString := decoded.(string)
	if !isString {
		return decoded
	}

	// Only text that looks like a JSON array/object is worth a second decode;
	// this also rejects the empty string.
	nested = strings.TrimSpace(nested)
	if !strings.HasPrefix(nested, "[") && !strings.HasPrefix(nested, "{") {
		return []any{}
	}

	var unwrapped any
	if json.Unmarshal([]byte(nested), &unwrapped) != nil || unwrapped == nil {
		return []any{}
	}
	return unwrapped
}
// normalizeFormValueFromJSONBytes decodes b as JSON. Invalid JSON and JSON null
// yield an empty []any. A decoded string is handed to
// normalizeFormValueFromJSONString, because the bytes may themselves hold a
// JSON string; any other decoded value is returned as-is.
func normalizeFormValueFromJSONBytes(b []byte) any {
	var decoded any
	if err := json.Unmarshal(b, &decoded); err != nil {
		return []any{}
	}

	switch val := decoded.(type) {
	case nil:
		return []any{}
	case string:
		return normalizeFormValueFromJSONString(val)
	default:
		return val
	}
}
// ParseJSONField decodes a *gvar.Var holding JSON text into a Go value.
//
// Values of any other type are returned unchanged. A nil pointer, or a Var
// reporting IsNil or IsEmpty, yields nil. When the Var's string form is not
// valid JSON, that string itself is returned.
func ParseJSONField(field any) any {
	v, isVar := field.(*gvar.Var)
	if !isVar {
		return field
	}
	if v == nil || v.IsNil() || v.IsEmpty() {
		return nil
	}

	raw := v.String()
	var decoded any
	if err := json.Unmarshal([]byte(raw), &decoded); err != nil {
		// Not JSON: hand back the plain string form.
		return raw
	}
	return decoded
}

View File

@@ -5,12 +5,14 @@ import (
"fmt"
"strings"
"time"
"unicode/utf8"
"model-asynch/dao"
"model-asynch/model/entity"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/grpool"
"github.com/tidwall/gjson"
)
// AsyncWorker is the package-level *asyncWorker value on which worker methods
// such as RunByTaskID are invoked.
var AsyncWorker = &asyncWorker{}
@@ -43,7 +45,7 @@ func (w *asyncWorker) RunOnce(ctx context.Context, batchSize, goroutines int) (c
for _, t := range tasks {
task := t
_ = pool.AddWithRecover(ctx, func(ctx context.Context) {
w.handleOne(ctx, task)
w.handleOne(ctx, task, 0)
done <- struct{}{}
}, func(ctx context.Context, e error) {
if e != nil {
@@ -59,8 +61,23 @@ func (w *asyncWorker) RunOnce(ctx context.Context, batchSize, goroutines int) (c
return claimed, nil
}
func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask) {
// 从任务入库的 request_payload 里恢复 payload + headers给 OSS 上传透传鉴权用
// RunByTaskID tries to execute the given task right after it has been created:
//   - it claims only the pending task that matches taskID;
//   - when nothing is claimed (per the original note: the task was taken by
//     another worker or is no longer pending) it returns nil without doing anything.
//
// A claim error is returned unchanged. A claimed task is passed to handleOne
// together with epicycleId, and nil is returned afterwards.
func (w *asyncWorker) RunByTaskID(ctx context.Context, taskID string, epicycleId int64) error {
	claimed, err := dao.Task.ClaimPendingByTaskIDGlobal(ctx, taskID)
	switch {
	case err != nil:
		return err
	case claimed != nil:
		w.handleOne(ctx, claimed, epicycleId)
	}
	return nil
}
func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicycleId int64) {
// 从任务入库的 request_payload 里恢复 payload + headers
payload, headers := parseStoredPayload(t.RequestPayload)
if len(headers) > 0 {
ctx = setTaskHeadersToCtx(ctx, headers)
@@ -71,26 +88,42 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask) {
if err != nil {
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
// ============ 失败回调 ============
t.State = 3
t.ErrorMsg = err.Error()
go triggerCallback(context.WithoutCancel(ctx), t)
// ================================
return
}
if m == nil || m.Enabled != 1 {
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, "模型不存在或未启用")
errMsg := "模型不存在或未启用"
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, errMsg)
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
// ============ 失败回调 ============
t.State = 3
t.ErrorMsg = errMsg
go triggerCallback(context.WithoutCancel(ctx), t)
// ================================
return
}
// 2) 分布式并发限制(按 model_name 全局维度)
// 2) 分布式并发限制
semKey := fmt.Sprintf("asynch:sem:%s", t.ModelName)
leaseSeconds := int64(3600) // 兜底1小时
leaseSeconds := int64(3600)
maxC := GetRuntimeMaxConcurrency(ctx, t.ModelName, m.MaxConcurrency)
acquired, err := acquireSemaphore(ctx, semKey, maxC, leaseSeconds)
if err != nil {
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
// ============ 失败回调 ============
t.State = 3
t.ErrorMsg = err.Error()
go triggerCallback(context.WithoutCancel(ctx), t)
// ================================
return
}
if !acquired {
// 并发满了:放回排队(重新置回 state=0下一轮再抢占
// 并发满了:放回排队,不回调(不是失败)
_ = w.rollbackToPending(ctx, t.Id)
return
}
@@ -109,30 +142,40 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask) {
data []byte
contentType string
ext string
textResult string
)
// phase=1 表示模型已成功但 OSS 上传失败:优先从临时文件加载,避免重复跑模型
// phase=1 表示模型已成功但 OSS 上传失败:优先从临时文件加载
if t.Phase == 1 && strings.TrimSpace(t.TmpFile) != "" {
data, err = loadTmpResult(t.TmpFile)
if err == nil && len(data) > 0 {
contentType, ext = DetectFileType(data)
} else {
// 临时文件不可用:回退重新调用模型
data = nil
}
}
if data == nil {
// 统计:仅在真正请求模型时 +1OSS 重试不计入)
// 统计
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(t.TenantId), t.Creator, t.ModelName)
// 核心调用
data, err = InvokeModel(ctx, m, payload, t.ModelKey)
if err != nil {
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
// ============ 失败回调 ============
t.State = 3
t.ErrorMsg = err.Error()
go triggerCallback(context.WithoutCancel(ctx), t)
// ================================
return
}
contentType, ext = DetectFileType(data)
// 将模型输出写入临时文件,后续若 OSS 失败可只重试 OSS
if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") {
textResult = string(data)
if len(textResult) > 20000 {
textResult = textResult[:20000]
}
}
tmpPath, err := saveTmpResult(t.TaskID, data, ext)
if err == nil && tmpPath != "" {
t.TmpFile = tmpPath
@@ -147,26 +190,46 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask) {
// OSS 阶段失败:保留临时文件,下一轮仅重试 OSS
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, t.Id, err.Error())
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
// ============ OSS失败不回调还会重试 ============
// 注意OSS失败保留临时文件下次重试所以这里不触发最终回调
// 如果已经重试多次还没成功,需要在任务超时或超过最大重试次数时才回调失败
return
}
// 5) 更新任务状态成功
// 注意expire_at 的计算改为“已下载(state=4)后开始计时”,因此成功(state=2)不写 expire_at。
fileType := strings.TrimPrefix(ext, ".")
if fileType == "" {
fileType = contentType
}
if err := dao.Task.UpdateSuccessGlobal(ctx, t.Id, ossURL, fileType, int64(len(data)), nil); err != nil {
if err := dao.Task.UpdateSuccessGlobal(
ctx,
t.Id,
ossURL,
fileType,
textResult,
int64(len(data)),
nil,
GetExpendTokens(m.TokenMapping, textResult),
); err != nil {
g.Log().Errorf(ctx, "[worker] update success failed: %v", err)
return
}
// 成功/失败均不再占用 queue_limitstate=0/1 才占用)
// 成功/失败均不再占用 queue_limit
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
// 6) 成功回调(不影响主流程)
// 6) 成功回调
t.State = 2
t.OssFile = ossURL
t.FileType = fileType
go triggerSuccessCallback(context.WithoutCancel(ctx), t)
t.TextResult = textResult
g.Log().Infof(ctx, "[CALLBACK][DISPATCH] taskId=%s bizName=%s callbackUrl=%s", t.TaskID, t.BizName, t.CallbackURL)
go triggerCallback(context.WithoutCancel(ctx), t)
// ============ 如果有 epicycleId也触发业务回调 ============
if epicycleId != 0 {
go triggerPromptsCallback(context.WithoutCancel(ctx), t, epicycleId)
}
// 成功后清理临时文件
deleteTmpResult(t.TmpFile)
}
@@ -174,3 +237,13 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask) {
// rollbackToPending delegates to dao.Task.RollbackToPendingGlobal for the task
// row with the given id and returns that call's error unchanged.
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
	err := dao.Task.RollbackToPendingGlobal(ctx, id)
	return err
}
// GetExpendTokens extracts the consumed-token value from textResult using the
// gjson path tokenMapping. When that path does not exist in textResult, the
// byte length of textResult is returned instead.
func GetExpendTokens(tokenMapping string, textResult string) int {
	if hit := gjson.Get(textResult, tokenMapping); hit.Exists() {
		return int(hit.Int())
	}
	return len(textResult)
}