refactor(model): 重构模型实体和数据访问层

This commit is contained in:
2026-05-21 10:41:37 +08:00
parent a080a5536d
commit 170568e03e
35 changed files with 903 additions and 1072 deletions

View File

@@ -2,8 +2,10 @@ package service
import (
"context"
"errors"
"fmt"
"math"
"model-gateway/model/dto"
"model-gateway/consts/public"
"model-gateway/model/entity"
@@ -34,9 +36,12 @@ type AutoTuneResult struct {
// - 基于吞吐与 P90 执行耗时估算 max_concurrency 的运行时值(不超过 cap
// - queue_limit 与 expected_seconds 绑定(允许排队时间 = expected_seconds * 2生成运行时值不超过 cap
// - 单次调整幅度限制 ±50%,写入 Redis带 TTL
func AutoTune(ctx context.Context, windowSeconds int) ([]AutoTuneResult, error) {
if windowSeconds <= 0 {
windowSeconds = 3600
func AutoTune(ctx context.Context, req *dto.AutoTuneReq) (res *dto.AutoTuneRes, err error) {
if req == nil {
return nil, errors.New("request cannot be nil")
}
if req.WindowSeconds <= 0 {
req.WindowSeconds = 3600 // 默认1小时
}
// 1) 读取模型配置cap按 model_name 聚合去重(如果表里有多租户重复数据,取较大上限)
var modelRows []*entity.AsynchModel
@@ -68,7 +73,7 @@ func AutoTune(ctx context.Context, windowSeconds int) ([]AutoTuneResult, error)
}
}
if len(modelMap) == 0 {
return []AutoTuneResult{}, nil
return nil, errors.New("no models found")
}
// 2) 统计指定窗口:按 model_name 计算 cnt 和 P90 执行耗时
@@ -89,7 +94,7 @@ SELECT model_name,
AND finished_at IS NOT NULL
AND finished_at >= (NOW() - (? || ' seconds')::interval)
GROUP BY model_name`, public.TableNameTask)
r, err := gfdb.DB(ctx).GetAll(ctx, sql, windowSeconds)
r, err := gfdb.DB(ctx).GetAll(ctx, sql, req.WindowSeconds)
if err != nil {
return nil, err
}
@@ -189,6 +194,8 @@ SELECT model_name,
})
}
g.Log().Infof(ctx, "[auto_tune] done models=%d windowSeconds=%d", len(out), windowSeconds)
return out, nil
g.Log().Infof(ctx, "[auto_tune] done models=%d windowSeconds=%d", len(out), req.WindowSeconds)
return &dto.AutoTuneRes{
List: out,
}, nil
}

View File

@@ -1,67 +0,0 @@
package service
import (
"context"
"encoding/json"
"model-gateway/model/entity"
"gitea.com/red-future/common/http"
"github.com/gogf/gf/v2/frame/g"
)
// triggerCallback posts the outcome of a task to the task's callback endpoint.
//
// The JSON body carries task_id, state, oss_file, file_type, text and error_msg.
// The call is best-effort: marshalling or transport failures are logged as
// warnings and nothing is returned to the caller.
func triggerCallback(ctx context.Context, t *entity.AsynchTask) {
	// The target is the business name immediately followed by the callback path.
	target := t.BizName + t.CallbackURL
	hdrs := forwardHeaders(ctx)

	body, err := json.Marshal(map[string]any{
		"task_id":   t.TaskID,
		"state":     t.State,
		"oss_file":  t.OssFile,
		"file_type": t.FileType,
		"text":      t.TextResult,
		"error_msg": t.ErrorMsg,
	})
	if err != nil {
		g.Log().Warningf(ctx, "[回调] JSON序列化失败 taskId=%s 错误=%v", t.TaskID, err)
		return
	}

	g.Log().Infof(ctx, "[回调] 开始发送 taskId=%s 回调地址=%s 请求头数量=%d 消息体大小=%d字节",
		t.TaskID, target, len(hdrs), len(body))

	// Nothing from the reply is used; an empty struct is passed as the output argument.
	var reply struct{}
	if err = http.Post(ctx, target, hdrs, &reply, body); err != nil {
		g.Log().Warningf(ctx, "[回调] 发送失败 taskId=%s 回调地址=%s 错误=%v", t.TaskID, target, err)
		return
	}
	g.Log().Infof(ctx, "[回调] 发送成功 taskId=%s 回调地址=%s 消息体大小=%d字节", t.TaskID, target, len(body))
}
// triggerPromptsCallback notifies the prompts service that a task has produced its answer.
//
// The JSON body carries epicycleId (the round id passed in by the caller) and
// text (the task's TextResult). The call is best-effort: failures are logged as
// warnings and nothing is returned.
//
// Every log line reports the epicycleId argument, i.e. the id that is actually
// sent in the payload, so the logs cannot disagree with the request body.
func triggerPromptsCallback(ctx context.Context, t *entity.AsynchTask, epicycleId int64) {
	const callbackURL = "prompts-core/session/sessionCallback"
	headers := forwardHeaders(ctx)

	jsonData, err := json.Marshal(map[string]any{
		"epicycleId": epicycleId,
		"text":       t.TextResult,
	})
	if err != nil {
		g.Log().Warningf(ctx, "[提示词回调] JSON序列化失败 epicycleId=%d 错误=%v", epicycleId, err)
		return
	}

	g.Log().Infof(ctx, "[提示词回调] 开始发送 epicycleId=%d 回调地址=%s 请求头数量=%d 消息体大小=%d字节",
		epicycleId, callbackURL, len(headers), len(jsonData))

	// Nothing from the reply is used; an empty struct is passed as the output argument.
	var reply struct{}
	if err = http.Post(ctx, callbackURL, headers, &reply, jsonData); err != nil {
		g.Log().Warningf(ctx, "[提示词回调] 发送失败 epicycleId=%d 回调地址=%s 错误=%v", epicycleId, callbackURL, err)
		return
	}
	g.Log().Infof(ctx, "[提示词回调] 发送成功 epicycleId=%d 回调地址=%s 消息体大小=%d字节", epicycleId, callbackURL, len(jsonData))
}

View File

@@ -2,6 +2,8 @@ package service
import (
"context"
"model-gateway/model/dto"
"os"
"time"
"model-gateway/dao"
@@ -14,14 +16,14 @@ var Cleaner = &cleaner{}
type cleaner struct{}
// RunOnce 由上层定时任务触发:执行一次清理/重试
func (c *cleaner) RunOnce(ctx context.Context) {
func (c *cleaner) RunOnce(ctx context.Context) (res *dto.CleanWorkRes, err error) {
// 1) 清理已下载(state=4)且过期的任务(硬删除 + OSS
expired, err := dao.Task.ListExpiredDownloadedGlobal(ctx, 200)
if err != nil {
g.Log().Errorf(ctx, "[cleaner] list expired(downloaded) error: %v", err)
} else {
for _, t := range expired {
deleteTmpResult(t.TmpFile)
_ = os.Remove(t.TmpFile)
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
}
g.Log().Infof(ctx, "[cleaner] expired(downloaded) cleaned, count=%d", len(expired))
@@ -82,11 +84,14 @@ func (c *cleaner) RunOnce(ctx context.Context) {
g.Log().Errorf(ctx, "[cleaner] list failed exhausted error: %v", err)
} else {
for _, t := range exhausted {
deleteTmpResult(t.TmpFile)
_ = os.Remove(t.TmpFile)
// 重试耗尽硬删除:释放闸门占位(兜底,若此前已释放则幂等)
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
}
g.Log().Infof(ctx, "[cleaner] failed exhausted cleaned, count=%d", len(exhausted))
}
return &dto.CleanWorkRes{
Ok: true,
}, nil
}

View File

@@ -1,47 +1 @@
package service
import (
"net/http"
"strings"
)
// knownFileExts lists the sniffed MIME types that map to a fixed file extension.
var knownFileExts = map[string]string{
	"audio/mpeg":       ".mp3",
	"audio/wave":       ".wav",
	"audio/wav":        ".wav",
	"audio/x-wav":      ".wav",
	"video/mp4":        ".mp4",
	"image/png":        ".png",
	"image/jpeg":       ".jpg",
	"application/pdf":  ".pdf",
	"text/plain":       ".txt",
	"application/json": ".json",
}

// DetectFileType sniffs a binary payload and returns its MIME type together
// with a file extension (including the leading dot).
//
// Empty input yields "application/octet-stream" and an empty extension.
// Types without a fixed mapping fall back to "." + the MIME subtype
// (for example image/gif gives ".gif").
func DetectFileType(data []byte) (contentType string, ext string) {
	if len(data) == 0 {
		return "application/octet-stream", ""
	}

	mime := http.DetectContentType(data)
	// http.DetectContentType may append parameters such as "; charset=utf-8"; keep only the type itself.
	if head, _, found := strings.Cut(mime, ";"); found && head != "" {
		mime = strings.TrimSpace(head)
	}

	if known, ok := knownFileExts[mime]; ok {
		return mime, known
	}

	// Fallback: derive the extension from the subtype of a well-formed "type/subtype" value.
	segments := strings.Split(mime, "/")
	if len(segments) != 2 {
		return mime, ""
	}
	subtype := segments[1]
	// Guard against a parameter suffix ending up in the extension.
	if head, _, found := strings.Cut(subtype, ";"); found && head != "" {
		subtype = strings.TrimSpace(head)
	}
	return mime, "." + subtype
}

View File

@@ -0,0 +1,171 @@
package gateway
import (
"bytes"
"context"
"encoding/json"
"fmt"
"mime/multipart"
"model-gateway/common/util"
"model-gateway/model/entity"
"time"
commonHttp "gitea.com/red-future/common/http"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/guid"
)
// uploadFileResponse is the structure the reply of the "oss/file/uploadFile"
// request is decoded into by UploadByTask.
type uploadFileResponse struct {
FileURL string `json:"fileURL"` // URL of the stored file
FileSize int `json:"fileSize"` // file size in bytes
FileName string `json:"fileName"` // file name
FileFormat string `json:"fileFormat"` // file format
FileAddressPrefix string `json:"fileAddressPrefix"` // prefix of the file address
}
// UploadByTask uploads data to the OSS file service (POST oss/file/uploadFile)
// as a multipart/form-data request with a single "file" part and returns the
// URL reported in the reply.
//
// The task argument and the trailing string argument are not used.
// fileExt may be given with or without a leading dot; an empty value falls
// back to ".bin". The generated file name is asynch_<unix-seconds>_<guid><ext>.
//
// An error is returned when the form cannot be built or the POST fails.
func UploadByTask(ctx context.Context, _ *entity.AsynchTask, data []byte, fileExt string, _ string) (ossURL string, err error) {
	ext := fileExt
	switch {
	case ext == "":
		ext = ".bin"
	case ext[0] != '.':
		ext = "." + ext
	}
	filename := fmt.Sprintf("asynch_%d_%s%s", time.Now().Unix(), guid.S(), ext)

	// Build the multipart body.
	body := &bytes.Buffer{}
	writer := multipart.NewWriter(body)
	part, err := writer.CreateFormFile("file", filename)
	if err != nil {
		return "", err
	}
	if _, err = part.Write(data); err != nil {
		return "", err
	}
	// Close writes the terminating boundary; without it the body is an
	// incomplete multipart message that the receiver cannot parse.
	contentType := writer.FormDataContentType()
	if err = writer.Close(); err != nil {
		return "", err
	}

	headers := util.ForwardHeaders(ctx)
	if headers == nil {
		headers = make(map[string]string)
	}
	// The receiver needs the boundary parameter carried by this header to split the parts.
	headers["Content-Type"] = contentType

	fullURL := "oss/file/uploadFile"
	g.Log().Infof(ctx, "[OSS] upload start url=%s filename=%s size=%d", fullURL, filename, len(data))

	var resp uploadFileResponse
	if err = commonHttp.Post(ctx, fullURL, headers, &resp, body.Bytes()); err != nil {
		return "", err
	}

	g.Log().Infof(ctx, "[OSS] upload success url=%s size=%d format=%s", resp.FileURL, resp.FileSize, resp.FileFormat)
	return resp.FileURL, nil
}
// TriggerCallback posts the outcome of a task to the task's CallbackURL.
//
// The JSON body carries task_id, state, oss_file, file_type, text and error_msg.
// The call is best-effort: marshalling or transport failures are logged as
// warnings and nothing is returned to the caller.
//
// NOTE(review): the pre-refactor triggerCallback posted to t.BizName + t.CallbackURL;
// here only t.CallbackURL is used — confirm CallbackURL is now a complete target.
func TriggerCallback(ctx context.Context, t *entity.AsynchTask) {
	hdrs := util.ForwardHeaders(ctx)

	body, err := json.Marshal(map[string]any{
		"task_id":   t.TaskID,
		"state":     t.State,
		"oss_file":  t.OssFile,
		"file_type": t.FileType,
		"text":      t.TextResult,
		"error_msg": t.ErrorMsg,
	})
	if err != nil {
		g.Log().Warningf(ctx, "[回调] JSON序列化失败 taskId=%s 错误=%v", t.TaskID, err)
		return
	}

	g.Log().Infof(ctx, "[回调] 开始发送 taskId=%s 回调地址=%s 请求头数量=%d 消息体大小=%d字节",
		t.TaskID, t.CallbackURL, len(hdrs), len(body))

	// Nothing from the reply is used; an empty struct is passed as the output argument.
	var reply struct{}
	if err = commonHttp.Post(ctx, t.CallbackURL, hdrs, &reply, body); err != nil {
		g.Log().Warningf(ctx, "[回调] 发送失败 taskId=%s 回调地址=%s 错误=%v", t.TaskID, t.CallbackURL, err)
		return
	}
	g.Log().Infof(ctx, "[回调] 发送成功 taskId=%s 回调地址=%s 消息体大小=%d字节", t.TaskID, t.CallbackURL, len(body))
}
// TriggerPromptsCallback notifies the prompts service that a task has produced its answer.
//
// The JSON body carries epicycleId (the round id passed in by the caller) and
// text (the task's TextResult). The call is best-effort: failures are logged as
// warnings and nothing is returned.
//
// Every log line reports the epicycleId argument, i.e. the id that is actually
// sent in the payload, so the logs cannot disagree with the request body.
func TriggerPromptsCallback(ctx context.Context, t *entity.AsynchTask, epicycleId int64) {
	const callbackURL = "prompts-core/session/sessionCallback"
	headers := util.ForwardHeaders(ctx)

	jsonData, err := json.Marshal(map[string]any{
		"epicycleId": epicycleId,
		"text":       t.TextResult,
	})
	if err != nil {
		g.Log().Warningf(ctx, "[提示词回调] JSON序列化失败 epicycleId=%d 错误=%v", epicycleId, err)
		return
	}

	g.Log().Infof(ctx, "[提示词回调] 开始发送 epicycleId=%d 回调地址=%s 请求头数量=%d 消息体大小=%d字节",
		epicycleId, callbackURL, len(headers), len(jsonData))

	// Nothing from the reply is used; an empty struct is passed as the output argument.
	var reply struct{}
	if err = commonHttp.Post(ctx, callbackURL, headers, &reply, jsonData); err != nil {
		g.Log().Warningf(ctx, "[提示词回调] 发送失败 epicycleId=%d 回调地址=%s 错误=%v", epicycleId, callbackURL, err)
		return
	}
	g.Log().Infof(ctx, "[提示词回调] 发送成功 epicycleId=%d 回调地址=%s 消息体大小=%d字节", epicycleId, callbackURL, len(jsonData))
}
// IsSuperAdmin 调用admin-go服务检查是否是超级管理员
func IsSuperAdmin(ctx context.Context) (res bool, err error) {
headers := util.ForwardHeaders(ctx)
var r = make(map[string]bool)
if err = commonHttp.Get(ctx, "admin-go/api/v1/system/user/checkIsSuperAdmin", headers, &r); err != nil {
return false, err
}
return r["isSuperAdmin"], err
}
//// callback 向回调地址 POST 任务结果(与查询接口 GetTaskRes 出参一致)
//func (s *audioTaskService) callback(ctx context.Context, taskID, status, errMsg, callbackURL string) {
// if callbackURL == "" {
// return
// }
//
// task, _ := dao.TranscribeTask.GetByTaskID(ctx, taskID)
// if task == nil {
// g.Log().Errorf(ctx, "[回调 %s] 任务不存在", taskID)
// return
// }
//
// detailList, _ := dao.TranscribeTaskDetail.ListByTaskID(ctx, taskID)
// detailItems := make([]dto.TranscribeTaskDetailItem, 0, len(detailList))
// for i := range detailList {
// detailItems = append(detailItems, dao.DetailEntityToItem(&detailList[i]))
// }
//
// // 构建与查询接口一致的 taskInfo
// taskInfo := dao.EntityToItem(task)
//
// // 兼容历史数据: 从 result 中补全 scenes 等字段
// detailItems = enrichDetailsFromResult(task.Result, detailItems)
//
// payload := dto.CallbackPayload{
// TaskInfo: taskInfo,
// DetailList: detailItems,
// }
//
// body, _ := json.Marshal(payload)
//
// // 透传调用方的用户信息
// userJSON, _ := json.Marshal(beans.User{UserName: "admin", TenantId: 1})
//
// req, _ := http.NewRequest("POST", callbackURL, bytes.NewReader(body))
// req.Header.Set("Content-Type", "application/json")
// req.Header.Set("X-User-Info", string(userJSON))
//
// resp, reqErr := http.DefaultClient.Do(req)
// if reqErr != nil {
// g.Log().Errorf(ctx, "[回调 %s] 请求失败: %v", taskID, reqErr)
// return
// }
// defer resp.Body.Close()
//
// respBody, _ := io.ReadAll(resp.Body)
// g.Log().Infof(ctx, "[回调 %s] 响应 status=%d, body=%s", taskID, resp.StatusCode, string(respBody))
//}

View File

@@ -1,53 +0,0 @@
package service
import (
"context"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/frame/g"
)
// asyncCtx returns a context that is detached from the cancellation of ctx and
// carries the credentials asynchronous work needs after the request has ended:
// the Authorization header under key "token", the X-User-Info header under key
// "xUserInfo", and the resolved user under key "user".
//
// This only helps goroutines started within the same request. Because the
// project runs in "persist + background worker" mode, the necessary information
// is additionally persisted in the task table's request_payload.
func asyncCtx(ctx context.Context) context.Context {
	detached := context.WithoutCancel(ctx)

	if req := g.RequestFromCtx(ctx); req != nil {
		for _, kv := range [...]struct{ header, key string }{
			{"Authorization", "token"},
			{"X-User-Info", "xUserInfo"},
		} {
			if val := req.Header.Get(kv.header); val != "" {
				detached = context.WithValue(detached, kv.key, val)
			}
		}
	}

	user, err := utils.GetUserInfo(ctx)
	if err == nil && user != nil {
		detached = context.WithValue(detached, "user", user)
	}
	return detached
}
// forwardHeaders collects the headers that must be passed along the call chain:
// Authorization and X-User-Info.
//
// For each header the value pinned in ctx (keys "token" / "xUserInfo") wins;
// when it is absent or empty, the header of the current HTTP request (if any)
// is used instead. Headers with no value are left out of the returned map,
// which is never nil.
func forwardHeaders(ctx context.Context) map[string]string {
	out := make(map[string]string)
	req := g.RequestFromCtx(ctx)

	for _, kv := range [...]struct{ header, key string }{
		{"Authorization", "token"},
		{"X-User-Info", "xUserInfo"},
	} {
		if val, ok := ctx.Value(kv.key).(string); ok && val != "" {
			out[kv.header] = val
			continue
		}
		// Fallback: read the header from the incoming request.
		if req == nil {
			continue
		}
		if val := req.Header.Get(kv.header); val != "" {
			out[kv.header] = val
		}
	}
	return out
}

View File

@@ -3,13 +3,15 @@ package service
import (
"context"
"errors"
"model-gateway/common/util"
"model-gateway/consts/public"
"model-gateway/dao"
"model-gateway/model/dto"
"model-gateway/model/entity"
"model-gateway/service/gateway"
"gitea.com/red-future/common/beans"
"gitea.com/red-future/common/db/gfdb"
"gitea.com/red-future/common/http"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
@@ -20,28 +22,20 @@ var Model = &modelService{}
type modelService struct{}
// IsSuperAdmin 调用admin-go服务检查是否是超级管理员
func (s *modelService) IsSuperAdmin(ctx context.Context) (res bool, err error) {
headers := forwardHeaders(ctx)
var r = make(map[string]bool)
if err = http.Get(ctx, "admin-go/api/v1/system/user/checkIsSuperAdmin", headers, &r); err != nil {
return false, err
}
return r["isSuperAdmin"], err
}
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
// 获取当前会话模型
if !g.IsEmpty(req.IsChatModel) && *req.IsChatModel == 1 {
var model *entity.AsynchModel
model, err = dao.Model.GetByIsChatModel(ctx)
model, err = dao.Model.Get(ctx, &entity.AsynchModel{
IsChatModel: new(1),
})
if err != nil {
return nil, err
}
// 如果有会话模型,那就改变为 0
if model != nil {
_, err = dao.Model.Update(ctx, &dto.UpdateModelReq{
ID: model.Id,
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: model.Id},
IsChatModel: gconv.PtrInt(0),
})
if err != nil {
@@ -51,14 +45,40 @@ func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res
}
req.IsOwner = gconv.PtrInt(1)
admin, err := s.IsSuperAdmin(ctx)
admin, err := gateway.IsSuperAdmin(ctx)
if err != nil {
return
}
if admin {
req.IsOwner = gconv.PtrInt(0)
}
id, err := dao.Model.Insert(ctx, req)
id, err := dao.Model.Insert(ctx, &entity.AsynchModel{
ModelName: req.ModelName,
ModelType: req.ModelType,
BaseURL: req.BaseURL,
HttpMethod: req.HttpMethod,
HeadMsg: req.HeadMsg,
Form: req.Form,
RequestMapping: req.RequestMapping,
ResponseMapping: req.ResponseMapping,
ResponseBody: req.ResponseBody,
ResponseTokenField: req.ResponseTokenField,
IsPrivate: req.IsPrivate,
IsChatModel: req.IsChatModel,
ApiKey: req.ApiKey,
Enabled: req.Enabled,
MaxConcurrency: req.MaxConcurrency,
QueueLimit: req.QueueLimit,
TimeoutSeconds: req.TimeoutSeconds,
ExpectedSeconds: req.ExpectedSeconds,
RetryTimes: req.RetryTimes,
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
AutoCleanSeconds: req.AutoCleanSeconds,
Remark: req.Remark,
IsOwner: req.IsOwner,
OperatorName: req.OperatorName,
TokenConfig: req.TokenConfig,
})
if err != nil {
return nil, err
}
@@ -69,7 +89,9 @@ func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) erro
//根据当前 isChatModel 来判断是否更新模型
if req.IsChatModel == gconv.PtrInt(1) {
//判断当前用户是否有会话模型
model, err := dao.Model.GetByIsChatModel(ctx)
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
IsChatModel: new(1),
})
if err != nil {
return err
}
@@ -79,68 +101,146 @@ func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) erro
}
req.IsOwner = gconv.PtrInt(1)
admin, err := s.IsSuperAdmin(ctx)
admin, err := gateway.IsSuperAdmin(ctx)
if err != nil {
return err
}
if admin {
req.IsOwner = gconv.PtrInt(0)
_, err = dao.Model.Update(ctx, req)
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
ModelName: req.ModelName,
ModelType: req.ModelType,
BaseURL: req.BaseURL,
HttpMethod: req.HttpMethod,
HeadMsg: req.HeadMsg,
Form: req.Form,
RequestMapping: req.RequestMapping,
ResponseMapping: req.ResponseMapping,
ResponseBody: req.ResponseBody,
ResponseTokenField: req.ResponseTokenField,
IsPrivate: req.IsPrivate,
IsChatModel: req.IsChatModel,
ApiKey: req.ApiKey,
Enabled: req.Enabled,
MaxConcurrency: req.MaxConcurrency,
QueueLimit: req.QueueLimit,
TimeoutSeconds: req.TimeoutSeconds,
ExpectedSeconds: req.ExpectedSeconds,
RetryTimes: req.RetryTimes,
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
AutoCleanSeconds: req.AutoCleanSeconds,
Remark: req.Remark,
IsOwner: req.IsOwner,
OperatorName: req.OperatorName,
TokenConfig: req.TokenConfig,
})
if err != nil {
return err
}
return nil
}
var user *beans.User
user, err = utils.GetUserInfo(ctx)
if err != nil {
return err
}
// 判断当前传过来的模型id的模型是否是超级管理员的。如果是超管的进行创建否则更新
var count int
count, err = dao.Model.Count(ctx, &dto.GetModelReq{
ID: req.ID,
Creator: user.UserName,
model, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
})
if err != nil {
return err
}
if count == 0 {
if model.TenantId == 1 {
insertDto := new(dto.CreateModelReq)
err = gconv.Struct(req, insertDto)
if err != nil {
return err
}
_, err = dao.Model.Insert(ctx, insertDto)
_, err = dao.Model.Insert(ctx, &entity.AsynchModel{
ModelName: req.ModelName,
ModelType: req.ModelType,
BaseURL: req.BaseURL,
HttpMethod: req.HttpMethod,
HeadMsg: req.HeadMsg,
Form: req.Form,
RequestMapping: req.RequestMapping,
ResponseMapping: req.ResponseMapping,
ResponseBody: req.ResponseBody,
ResponseTokenField: req.ResponseTokenField,
IsPrivate: req.IsPrivate,
IsChatModel: req.IsChatModel,
ApiKey: req.ApiKey,
Enabled: req.Enabled,
MaxConcurrency: req.MaxConcurrency,
QueueLimit: req.QueueLimit,
TimeoutSeconds: req.TimeoutSeconds,
ExpectedSeconds: req.ExpectedSeconds,
RetryTimes: req.RetryTimes,
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
AutoCleanSeconds: req.AutoCleanSeconds,
Remark: req.Remark,
IsOwner: req.IsOwner,
OperatorName: req.OperatorName,
TokenConfig: req.TokenConfig,
})
return err
}
_, err = dao.Model.Update(ctx, req)
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
ModelName: req.ModelName,
ModelType: req.ModelType,
BaseURL: req.BaseURL,
HttpMethod: req.HttpMethod,
HeadMsg: req.HeadMsg,
Form: req.Form,
RequestMapping: req.RequestMapping,
ResponseMapping: req.ResponseMapping,
ResponseBody: req.ResponseBody,
ResponseTokenField: req.ResponseTokenField,
IsPrivate: req.IsPrivate,
IsChatModel: req.IsChatModel,
ApiKey: req.ApiKey,
Enabled: req.Enabled,
MaxConcurrency: req.MaxConcurrency,
QueueLimit: req.QueueLimit,
TimeoutSeconds: req.TimeoutSeconds,
ExpectedSeconds: req.ExpectedSeconds,
RetryTimes: req.RetryTimes,
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
AutoCleanSeconds: req.AutoCleanSeconds,
Remark: req.Remark,
IsOwner: req.IsOwner,
OperatorName: req.OperatorName,
TokenConfig: req.TokenConfig,
})
return err
}
func (s *modelService) Delete(ctx context.Context, id string) error {
_, err := dao.Model.DeleteByID(ctx, id)
func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) error {
_, err := dao.Model.Delete(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
})
return err
}
func (s *modelService) Get(ctx context.Context, id int64) (*entity.AsynchModel, error) {
model, err := dao.Model.Get(ctx, id)
func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetModelRes, error) {
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
})
if err != nil {
return nil, err
}
model.Form = ParseJSONField(model.Form)
model.RequestMapping = ParseJSONField(model.RequestMapping)
model.ResponseMapping = ParseJSONField(model.ResponseMapping)
model.ResponseBody = ParseJSONField(model.ResponseBody)
return model, nil
model.Form = util.ParseJSONField(model.Form)
model.RequestMapping = util.ParseJSONField(model.RequestMapping)
model.ResponseMapping = util.ParseJSONField(model.ResponseMapping)
model.ResponseBody = util.ParseJSONField(model.ResponseBody)
return &dto.GetModelRes{
Model: model,
}, nil
}
func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (list []*entity.AsynchModel, total int, err error) {
func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) {
var models []*entity.AsynchModel
req.IsOwner = gconv.PtrInt(1)
admin, err := s.IsSuperAdmin(ctx)
admin, err := gateway.IsSuperAdmin(ctx)
if err != nil {
return
}
@@ -151,63 +251,55 @@ func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (list []
var user *beans.User
user, err = utils.GetUserInfo(ctx)
if err != nil {
return nil, 0, err
return nil, err
}
req.Creator = user.UserName
models, total, err = dao.Model.GetByCreatorAndPlatform(ctx, req)
models, total, err := dao.Model.GetByCreatorAndPlatform(ctx, req)
if err != nil {
return
}
// 处理列表中每条记录的 JSONB 字段
for _, m := range models {
m.Form = ParseJSONField(m.Form)
m.RequestMapping = ParseJSONField(m.RequestMapping)
m.ResponseMapping = ParseJSONField(m.ResponseMapping)
m.ResponseBody = ParseJSONField(m.ResponseBody)
m.Form = util.ParseJSONField(m.Form)
m.RequestMapping = util.ParseJSONField(m.RequestMapping)
m.ResponseMapping = util.ParseJSONField(m.ResponseMapping)
m.ResponseBody = util.ParseJSONField(m.ResponseBody)
}
return models, total, nil
return &dto.ListModelRes{
List: models,
Total: total,
}, nil
}
// GetModelTypesFromConfig 从配置文件读取模型类型
func GetModelTypesFromConfig(ctx context.Context) map[int]string {
typeMap := make(map[int]string)
// 读取配置
configMap := g.Cfg().MustGet(ctx, "modelType.types").Map()
for k, v := range configMap {
typeID := gconv.Int(k)
typeName := gconv.String(v)
if typeID > 0 && typeName != "" {
typeMap[typeID] = typeName
}
func GetModelTypesFromConfig() (res *dto.TypeItem, err error) {
// 返回副本,避免外部修改
types := make(map[int]string, len(public.ModelTypeName))
for k, v := range public.ModelTypeName {
types[k] = v
}
// 如果配置为空,使用默认值
if len(typeMap) == 0 {
typeMap = map[int]string{
1: "推理模型",
2: "图片模型",
3: "音频模型",
4: "向量化模型",
5: "全模态模型",
}
}
return typeMap
return &dto.TypeItem{
Type: types,
}, nil
}
func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error {
// 校验新会话模型是否存在
newModel, err := dao.Model.Get(ctx, req.Id)
newModel, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
})
if err != nil {
return err
}
if newModel == nil {
return errors.New("新会话模型不存在")
}
// 获取当前用户会话模型
currentModel, err := dao.Model.GetByIsChatModel(ctx)
currentModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
IsChatModel: new(1),
})
if err != nil {
return err
}
@@ -219,8 +311,8 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM
// 如果点击的就是当前会话模型已经是1取消它设为0
if currentModel.Id != req.Id {
_, err = dao.Model.Update(ctx, &dto.UpdateModelReq{
ID: currentModel.Id,
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id},
IsChatModel: gconv.PtrInt(0),
})
if err != nil {
@@ -230,8 +322,8 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM
}
// 设置当前为会话模型设为1
_, err = dao.Model.Update(ctx, &dto.UpdateModelReq{
ID: req.Id,
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id},
IsChatModel: gconv.PtrInt(1),
})
return err
@@ -239,17 +331,21 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM
return err
}
func (s *modelService) GetIsChatModel(ctx context.Context) (*entity.AsynchModel, error) {
model, err := dao.Model.GetByIsChatModel(ctx)
func (s *modelService) GetIsChatModel(ctx context.Context) (*dto.GetIsChatModelRes, error) {
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
IsChatModel: new(1),
})
if err != nil {
return nil, err
}
if model == nil {
return nil, nil
}
model.Form = ParseJSONField(model.Form)
model.RequestMapping = ParseJSONField(model.RequestMapping)
model.ResponseMapping = ParseJSONField(model.ResponseMapping)
model.ResponseBody = ParseJSONField(model.ResponseBody)
return model, nil
model.Form = util.ParseJSONField(model.Form)
model.RequestMapping = util.ParseJSONField(model.RequestMapping)
model.ResponseMapping = util.ParseJSONField(model.ResponseMapping)
model.ResponseBody = util.ParseJSONField(model.ResponseBody)
return &dto.GetIsChatModelRes{
Model: model,
}, nil
}

View File

@@ -1,25 +0,0 @@
package service
import "github.com/gogf/gf/v2/util/gconv"
// parseStoredPayload splits a persisted request_payload into the payload for
// the model call and the headers to forward.
//
// Stored format: {"payload": <any>, "headers": {"Authorization": "...", "X-User-Info": "..."}}
//
// A nil input yields (nil, nil). A value that does not convert to a non-empty
// map is returned unchanged as the payload with nil headers. When the map has
// no "payload" key, the whole value is used as the payload.
func parseStoredPayload(v any) (payload any, headers map[string]string) {
	if v == nil {
		return nil, nil
	}

	fields := gconv.Map(v)
	if len(fields) == 0 {
		return v, nil
	}

	if rawHeaders, present := fields["headers"]; present {
		headers = gconv.MapStrStr(rawHeaders)
	}

	payload = v
	if inner, present := fields["payload"]; present {
		payload = inner
	}
	return payload, headers
}

View File

@@ -1,18 +0,0 @@
package service
import (
"context"
"errors"
"model-gateway/model/entity"
)
// StorageService abstracts the result storage backend (OSS/MinIO).
type StorageService interface {
// UploadByTask stores data produced for task t and returns the URL of the
// stored object. fileExt and contentType describe the data being stored.
UploadByTask(ctx context.Context, t *entity.AsynchTask, data []byte, fileExt string, contentType string) (ossURL string, err error)
}
// Storage is the default storage implementation. It is backed by the in-house
// OSS file service (ossStorage); it can be switched to MinIO if necessary.
var Storage StorageService = &ossStorage{}
// ErrStorageNotConfigured is the sentinel error for a storage backend that has not been configured.
var ErrStorageNotConfigured = errors.New("存储未配置")

View File

@@ -1,81 +0,0 @@
package service
import (
"bytes"
"context"
"fmt"
"mime/multipart"
"time"
"model-gateway/model/entity"
commonHttp "gitea.com/red-future/common/http"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
"github.com/gogf/gf/v2/util/guid"
)
// ossStorage integrates with the in-house OSS file service:
// POST oss/file/uploadFile (multipart/form-data).
type ossStorage struct{}
// uploadFileResponse is the structure the reply of the upload request is decoded into.
type uploadFileResponse struct {
FileURL string `json:"fileURL"` // URL of the stored file
FileSize int `json:"fileSize"` // file size in bytes
FileName string `json:"fileName"` // file name
FileFormat string `json:"fileFormat"` // file format
FileAddressPrefix string `json:"fileAddressPrefix"` // prefix of the file address
}
// UploadByTask uploads data to the OSS file service (POST oss/file/uploadFile)
// as a multipart/form-data request with a single "file" part and returns the
// URL reported in the reply.
//
// The task argument and the trailing string argument are not used.
// fileExt may be given with or without a leading dot; an empty value falls
// back to ".bin". The generated file name is asynch_<unix-seconds>_<guid><ext>.
func (s *ossStorage) UploadByTask(ctx context.Context, _ *entity.AsynchTask, data []byte, fileExt string, _ string) (ossURL string, err error) {
	const endpoint = "oss/file/uploadFile"

	suffix := fileExt
	switch {
	case suffix == "":
		suffix = ".bin"
	case suffix[0] != '.':
		suffix = "." + suffix
	}
	name := fmt.Sprintf("asynch_%d_%s%s", time.Now().Unix(), guid.S(), suffix)

	// Assemble the multipart form in memory.
	var payload bytes.Buffer
	form := multipart.NewWriter(&payload)
	filePart, err := form.CreateFormFile("file", name)
	if err != nil {
		return "", err
	}
	if _, err = filePart.Write(data); err != nil {
		return "", err
	}
	// The content type carries the boundary; Close appends the terminating boundary.
	formType := form.FormDataContentType()
	if err = form.Close(); err != nil {
		return "", err
	}

	hdrs := forwardHeaders(ctx)
	hdrs["Content-Type"] = formType

	g.Log().Infof(ctx, "[OSS] upload start url=%s filename=%s size=%d", endpoint, name, len(data))

	var reply uploadFileResponse
	if err = commonHttp.Post(ctx, endpoint, hdrs, &reply, payload.Bytes()); err != nil {
		return "", err
	}

	g.Log().Infof(ctx, "[OSS] upload success url=%s size=%d format=%s", reply.FileURL, reply.FileSize, reply.FileFormat)
	return reply.FileURL, nil
}
// setTaskHeadersToCtx injects the headers stored with a task into ctx so the
// worker can use them when calling OSS: Authorization is stored under key
// "token" and X-User-Info under key "xUserInfo". Empty values are skipped and
// a nil map leaves ctx untouched.
func setTaskHeadersToCtx(ctx context.Context, headers map[string]string) context.Context {
	if headers == nil {
		return ctx
	}
	for _, kv := range [...]struct{ header, key string }{
		{"Authorization", "token"},
		{"X-User-Info", "xUserInfo"},
	} {
		if val := gconv.String(headers[kv.header]); val != "" {
			ctx = context.WithValue(ctx, kv.key, val)
		}
	}
	return ctx
}

View File

@@ -3,7 +3,7 @@ package service
import (
"context"
"errors"
"fmt"
"model-gateway/common/util"
"time"
"model-gateway/dao"
@@ -21,13 +21,13 @@ var Task = &taskService{}
type taskService struct{}
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
fmt.Printf("打印请求:%+v", req)
startAt := time.Now()
// 固化 token/user 等信息
ctx = asyncCtx(ctx)
ctx = util.AsyncCtx(ctx)
// 1) 检查模型配置
m, err := dao.Model.GetByModelName(ctx, req.ModelName)
m, err := dao.Model.Get(ctx, &entity.AsynchModel{
ModelName: req.ModelName,
})
if err != nil {
return nil, err
}
@@ -51,7 +51,7 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
// 将调用模型的 payload 与透传头信息一起存入 request_payload供后台 worker 使用
storedPayload := map[string]any{
"payload": req.RequestPayload,
"headers": forwardHeaders(ctx),
"headers": util.ForwardHeaders(ctx),
}
t := &entity.AsynchTask{
@@ -127,7 +127,9 @@ func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string,
defer ticker.Stop()
tryRun := func() bool {
t, err := dao.Task.GetByTaskID(ctx, taskID)
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
TaskID: taskID,
})
if err != nil {
g.Log().Warningf(ctx, "[task-auto-run][stop] taskId=%s reason=query_failed err=%v", taskID, err)
return true
@@ -138,7 +140,7 @@ func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string,
}
switch t.State {
case 0:
if err := AsyncWorker.RunByTaskID(ctx, taskID, epicycleId); err != nil {
if err = AsyncWorker.RunByTaskID(ctx, taskID, epicycleId); err != nil {
g.Log().Warningf(ctx, "[task-auto-run][retry] taskId=%s state=0 err=%v", taskID, err)
} else {
g.Log().Infof(ctx, "[task-auto-run][triggered] taskId=%s state=0", taskID)
@@ -175,7 +177,9 @@ func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string,
}
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
t, err := dao.Task.GetByTaskID(ctx, taskID)
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
TaskID: taskID,
})
if err != nil {
return nil, err
}
@@ -209,7 +213,9 @@ func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (r
continue
}
// 按模型配置决定保留时间
m, err := dao.Model.GetByModelName(ctx, t.ModelName)
m, err := dao.Model.Get(ctx, &entity.AsynchModel{
ModelName: t.ModelName,
})
if err != nil {
return nil, err
}

View File

@@ -1,38 +0,0 @@
package service
import (
"fmt"
"os"
"path/filepath"
)
// saveTmpResult writes a model's output to a temporary file so that a failed
// OSS upload can later be retried without re-running the model.
//
// The file is created as <os.TempDir()>/model-asynch/<taskID><ext>. ext may be
// given with or without a leading dot; an empty value falls back to ".bin".
// It returns the path of the written file, or an error when the directory
// cannot be created or the file cannot be written.
func saveTmpResult(taskID string, data []byte, ext string) (string, error) {
	switch {
	case ext == "":
		ext = ".bin"
	case ext[0] != '.':
		ext = "." + ext
	}

	baseDir := filepath.Join(os.TempDir(), "model-asynch")
	if mkErr := os.MkdirAll(baseDir, 0o755); mkErr != nil {
		return "", mkErr
	}

	target := filepath.Join(baseDir, fmt.Sprintf("%s%s", taskID, ext))
	if wrErr := os.WriteFile(target, data, 0o644); wrErr != nil {
		return "", wrErr
	}
	return target, nil
}
// loadTmpResult reads the whole file at path and returns its contents,
// passing through whatever error os.ReadFile reports.
func loadTmpResult(path string) ([]byte, error) {
return os.ReadFile(path)
}
// deleteTmpResult removes the file at path on a best-effort basis.
// An empty path is a no-op, and any error from os.Remove (for example the
// file no longer existing) is deliberately ignored.
func deleteTmpResult(path string) {
	if path != "" {
		_ = os.Remove(path)
	}
}

View File

@@ -1,113 +0,0 @@
package service
import (
"encoding/json"
"strings"
"github.com/gogf/gf/v2/container/gvar"
)
// normalizeFormValue converts a stored form value into decoded JSON so that
// callers always receive a JSON array/object rather than a raw string.
//
// nil, blank strings and empty byte slices become an empty array. Strings,
// byte slices and *gvar.Var values (a common GoFrame DB return type) are
// decoded as JSON. Other values exposing Bytes() or String() are decoded the
// same way when those return non-empty data; anything else is assumed to be
// structured already and is returned unchanged.
func normalizeFormValue(v any) any {
	if v == nil {
		return []any{}
	}

	switch val := v.(type) {
	case string:
		if trimmed := strings.TrimSpace(val); trimmed != "" {
			return normalizeFormValueFromJSONString(trimmed)
		}
		return []any{}
	case []byte:
		if len(val) > 0 {
			return normalizeFormValueFromJSONBytes(val)
		}
		return []any{}
	case *gvar.Var:
		if val == nil {
			return []any{}
		}
		// Prefer the raw bytes; fall back to the string form when they are empty.
		if raw := val.Bytes(); len(raw) > 0 {
			return normalizeFormValueFromJSONBytes(raw)
		}
		if text := strings.TrimSpace(val.String()); text != "" {
			return normalizeFormValueFromJSONString(text)
		}
		return []any{}
	}

	// Other JSON-like wrappers: anything exposing Bytes() or String().
	if byteser, ok := v.(interface{ Bytes() []byte }); ok {
		if raw := byteser.Bytes(); len(raw) > 0 {
			return normalizeFormValueFromJSONBytes(raw)
		}
	}
	if stringer, ok := v.(interface{ String() string }); ok {
		if text := strings.TrimSpace(stringer.String()); text != "" {
			return normalizeFormValueFromJSONString(text)
		}
	}
	// Already a structured value such as []any or map[string]any.
	return v
}
// normalizeFormValueFromJSONString decodes s as JSON and returns the result.
//
// It also copes with legacy rows where the JSONB column holds a JSON string
// whose content is itself JSON, e.g. form_json = '"[]"' or '"[{...}]"' (the
// outer layer is a string, the inner layer is the real array/object). In that
// case the inner text is decoded once more.
//
// An empty []any is returned when s is not valid JSON, decodes to null, or
// decodes to a string that does not contain a decodable array/object.
func normalizeFormValueFromJSONString(s string) any {
	var decoded any
	if json.Unmarshal([]byte(s), &decoded) != nil || decoded == nil {
		return []any{}
	}

	text, isString := decoded.(string)
	if !isString {
		return decoded
	}

	// The value is still a string: unwrap one more layer if it looks like JSON.
	text = strings.TrimSpace(text)
	looksLikeJSON := strings.HasPrefix(text, "[") || strings.HasPrefix(text, "{")
	if !looksLikeJSON {
		return []any{}
	}

	var nested any
	if json.Unmarshal([]byte(text), &nested) != nil || nested == nil {
		return []any{}
	}
	return nested
}
// normalizeFormValueFromJSONBytes decodes b as JSON and returns the result.
// An empty []any is returned when b is not valid JSON or decodes to null.
// When the decoded value is a string, it is handed to
// normalizeFormValueFromJSONString for another decoding pass.
func normalizeFormValueFromJSONBytes(b []byte) any {
	var decoded any
	if json.Unmarshal(b, &decoded) != nil || decoded == nil {
		return []any{}
	}

	switch val := decoded.(type) {
	case string:
		// The bytes held a JSON string; it may wrap the real array/object.
		return normalizeFormValueFromJSONString(val)
	default:
		return val
	}
}
// ParseJSONField decodes a *gvar.Var field as JSON.
//
// Values of any other type are returned unchanged. A nil, IsNil or IsEmpty
// *gvar.Var yields nil. Otherwise the variable's string form is decoded as
// JSON and the decoded value is returned; if decoding fails, the raw string
// is returned instead.
func ParseJSONField(field any) any {
	v, ok := field.(*gvar.Var)
	if !ok {
		return field
	}
	if v == nil || v.IsNil() || v.IsEmpty() {
		return nil
	}

	raw := v.String()
	var parsed any
	if err := json.Unmarshal([]byte(raw), &parsed); err != nil {
		// Not valid JSON: hand back the text as is.
		return raw
	}
	return parsed
}

View File

@@ -2,7 +2,13 @@ package service
import (
"context"
"errors"
"fmt"
"model-gateway/common/util"
"model-gateway/model/dto"
"model-gateway/service/gateway"
"os"
"path/filepath"
"strings"
"time"
"unicode/utf8"
@@ -23,24 +29,23 @@ type asyncWorker struct {
// RunOnce 由上层定时任务触发:一次性抢占并处理一批任务
// - batchSize: 本次抢占数量
// - goroutines: 本次并发数(协程池大小)
func (w *asyncWorker) RunOnce(ctx context.Context, batchSize, goroutines int) (claimed int, err error) {
if batchSize <= 0 {
batchSize = 10
func (w *asyncWorker) RunOnce(ctx context.Context, req *dto.RunWorkReq) (res *dto.RunWorkRes, err error) {
if req.BatchSize <= 0 {
req.BatchSize = 10
}
if goroutines <= 0 {
goroutines = 1
if req.Goroutines <= 0 {
req.Goroutines = 1
}
tasks, err := dao.Task.ClaimPendingGlobal(ctx, batchSize)
tasks, err := dao.Task.ClaimPendingGlobal(ctx, req.BatchSize)
if err != nil {
return 0, err
return nil, err
}
if len(tasks) == 0 {
return 0, nil
return nil, errors.New("no task to run")
}
pool := grpool.New(goroutines)
pool := grpool.New(req.Goroutines)
defer pool.Close()
claimed = len(tasks)
claimed := len(tasks)
done := make(chan struct{}, claimed)
for _, t := range tasks {
task := t
@@ -58,7 +63,9 @@ func (w *asyncWorker) RunOnce(ctx context.Context, batchSize, goroutines int) (c
for i := 0; i < claimed; i++ {
<-done
}
return claimed, nil
return &dto.RunWorkRes{
Claimed: claimed,
}, nil
}
// RunByTaskID 创建任务后立即异步尝试执行当前任务:
@@ -78,9 +85,9 @@ func (w *asyncWorker) RunByTaskID(ctx context.Context, taskID string, epicycleId
func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicycleId int64) {
// 从任务入库的 request_payload 里恢复 payload + headers
payload, headers := parseStoredPayload(t.RequestPayload)
payload, headers := util.ParseStoredPayload(t.RequestPayload)
if len(headers) > 0 {
ctx = setTaskHeadersToCtx(ctx, headers)
ctx = util.SetTaskHeadersToCtx(ctx, headers)
}
// 1) 拉取模型配置
@@ -91,7 +98,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
// ============ 失败回调 ============
t.State = 3
t.ErrorMsg = err.Error()
go triggerCallback(context.WithoutCancel(ctx), t)
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
// ================================
return
}
@@ -102,7 +109,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
// ============ 失败回调 ============
t.State = 3
t.ErrorMsg = errMsg
go triggerCallback(context.WithoutCancel(ctx), t)
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
// ================================
return
}
@@ -118,7 +125,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
// ============ 失败回调 ============
t.State = 3
t.ErrorMsg = err.Error()
go triggerCallback(context.WithoutCancel(ctx), t)
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
// ================================
return
}
@@ -147,9 +154,9 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
// phase=1 表示模型已成功但 OSS 上传失败:优先从临时文件加载
if t.Phase == 1 && strings.TrimSpace(t.TmpFile) != "" {
data, err = loadTmpResult(t.TmpFile)
data, err = os.ReadFile(t.TmpFile)
if err == nil && len(data) > 0 {
contentType, ext = DetectFileType(data)
contentType, ext = util.DetectFileType(data)
} else {
data = nil
}
@@ -165,11 +172,11 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
// ============ 失败回调 ============
t.State = 3
t.ErrorMsg = err.Error()
go triggerCallback(context.WithoutCancel(ctx), t)
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
// ================================
return
}
contentType, ext = DetectFileType(data)
contentType, ext = util.DetectFileType(data)
if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") {
textResult = string(data)
}
@@ -182,7 +189,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
}
// 4) 存储 OSS
ossURL, err := Storage.UploadByTask(ctx, t, data, ext, contentType)
ossURL, err := gateway.UploadByTask(ctx, t, data, ext, contentType)
if err != nil {
// OSS 阶段失败:保留临时文件,下一轮仅重试 OSS
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, t.Id, err.Error())
@@ -198,7 +205,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
if fileType == "" {
fileType = contentType
}
if err := dao.Task.UpdateSuccessGlobal(
if err = dao.Task.UpdateSuccessGlobal(
ctx,
t.Id,
ossURL,
@@ -206,7 +213,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
textResult,
int64(len(data)),
nil,
GetExpendTokens(m.TokenMapping, textResult),
GetExpendTokens(m.ResponseTokenField, textResult),
); err != nil {
g.Log().Errorf(ctx, "[worker] update success failed: %v", err)
return
@@ -221,14 +228,33 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
t.FileType = fileType
t.TextResult = textResult
g.Log().Infof(ctx, "[CALLBACK][DISPATCH] taskId=%s bizName=%s callbackUrl=%s", t.TaskID, t.BizName, t.CallbackURL)
go triggerCallback(context.WithoutCancel(ctx), t)
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
// ============ 如果有 epicycleId也触发业务回调 ============
if epicycleId != 0 {
go triggerPromptsCallback(context.WithoutCancel(ctx), t, epicycleId)
go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), t, epicycleId)
}
// 成功后清理临时文件
deleteTmpResult(t.TmpFile)
_ = os.Remove(t.TmpFile)
}
// saveTmpResult stores a model's output in a temporary file so that, when the
// OSS upload fails, a later attempt can retry only the OSS step.
//
// The file lives at <os.TempDir()>/model-asynch/<taskID><ext>. An empty ext
// falls back to ".bin", and a missing leading dot is added. The returned
// string is the path of the written file; on failure the error from creating
// the directory or writing the file is returned with an empty path.
func saveTmpResult(taskID string, data []byte, ext string) (string, error) {
	baseDir := filepath.Join(os.TempDir(), "model-asynch")
	mkErr := os.MkdirAll(baseDir, 0o755)
	if mkErr != nil {
		return "", mkErr
	}

	// Default the extension, then make sure it starts with a dot.
	if len(ext) == 0 {
		ext = ".bin"
	} else if ext[0] != '.' {
		ext = "." + ext
	}

	fileName := fmt.Sprint(taskID, ext)
	fullPath := filepath.Join(baseDir, fileName)
	writeErr := os.WriteFile(fullPath, data, 0o644)
	if writeErr != nil {
		return "", writeErr
	}
	return fullPath, nil
}
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
@@ -240,7 +266,6 @@ func GetExpendTokens(tokenMapping string, textResult string) int {
value := gjson.Get(textResult, tokenMapping)
if value.Exists() {
return int(value.Int())
} else {
return len(textResult)
}
return len(textResult)
}