refactor(model): 重构模型实体和数据访问层
This commit is contained in:
@@ -2,8 +2,10 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"model-gateway/model/dto"
|
||||
|
||||
"model-gateway/consts/public"
|
||||
"model-gateway/model/entity"
|
||||
@@ -34,9 +36,12 @@ type AutoTuneResult struct {
|
||||
// - 基于吞吐与 P90 执行耗时估算 max_concurrency 的运行时值(不超过 cap)
|
||||
// - queue_limit 与 expected_seconds 绑定(允许排队时间 = expected_seconds * 2),生成运行时值(不超过 cap)
|
||||
// - 单次调整幅度限制 ±50%,写入 Redis(带 TTL)
|
||||
func AutoTune(ctx context.Context, windowSeconds int) ([]AutoTuneResult, error) {
|
||||
if windowSeconds <= 0 {
|
||||
windowSeconds = 3600
|
||||
func AutoTune(ctx context.Context, req *dto.AutoTuneReq) (res *dto.AutoTuneRes, err error) {
|
||||
if req == nil {
|
||||
return nil, errors.New("request cannot be nil")
|
||||
}
|
||||
if req.WindowSeconds <= 0 {
|
||||
req.WindowSeconds = 3600 // 默认1小时
|
||||
}
|
||||
// 1) 读取模型配置(cap),按 model_name 聚合去重(如果表里有多租户重复数据,取较大上限)
|
||||
var modelRows []*entity.AsynchModel
|
||||
@@ -68,7 +73,7 @@ func AutoTune(ctx context.Context, windowSeconds int) ([]AutoTuneResult, error)
|
||||
}
|
||||
}
|
||||
if len(modelMap) == 0 {
|
||||
return []AutoTuneResult{}, nil
|
||||
return nil, errors.New("no models found")
|
||||
}
|
||||
|
||||
// 2) 统计指定窗口:按 model_name 计算 cnt 和 P90 执行耗时
|
||||
@@ -89,7 +94,7 @@ SELECT model_name,
|
||||
AND finished_at IS NOT NULL
|
||||
AND finished_at >= (NOW() - (? || ' seconds')::interval)
|
||||
GROUP BY model_name`, public.TableNameTask)
|
||||
r, err := gfdb.DB(ctx).GetAll(ctx, sql, windowSeconds)
|
||||
r, err := gfdb.DB(ctx).GetAll(ctx, sql, req.WindowSeconds)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -189,6 +194,8 @@ SELECT model_name,
|
||||
})
|
||||
}
|
||||
|
||||
g.Log().Infof(ctx, "[auto_tune] done models=%d windowSeconds=%d", len(out), windowSeconds)
|
||||
return out, nil
|
||||
g.Log().Infof(ctx, "[auto_tune] done models=%d windowSeconds=%d", len(out), req.WindowSeconds)
|
||||
return &dto.AutoTuneRes{
|
||||
List: out,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1,67 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"model-gateway/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/http"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
// triggerCallback 任务成功后的回调:
|
||||
// - JSON body 参数:task_id/state/oss_file/file_type/text(可选)
|
||||
func triggerCallback(ctx context.Context, t *entity.AsynchTask) {
|
||||
callbackURL := t.BizName + t.CallbackURL
|
||||
headers := forwardHeaders(ctx)
|
||||
var req struct{}
|
||||
payload := map[string]interface{}{
|
||||
"task_id": t.TaskID,
|
||||
"state": t.State,
|
||||
"oss_file": t.OssFile,
|
||||
"file_type": t.FileType,
|
||||
"text": t.TextResult,
|
||||
"error_msg": t.ErrorMsg,
|
||||
}
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[回调] JSON序列化失败 taskId=%s 错误=%v", t.TaskID, err)
|
||||
return
|
||||
}
|
||||
g.Log().Infof(ctx, "[回调] 开始发送 taskId=%s 回调地址=%s 请求头数量=%d 消息体大小=%d字节",
|
||||
t.TaskID, callbackURL, len(headers), len(jsonData))
|
||||
|
||||
err = http.Post(ctx, callbackURL, headers, &req, jsonData)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[回调] 发送失败 taskId=%s 回调地址=%s 错误=%v", t.TaskID, callbackURL, err)
|
||||
return
|
||||
}
|
||||
g.Log().Infof(ctx, "[回调] 发送成功 taskId=%s 回调地址=%s 消息体大小=%d字节", t.TaskID, callbackURL, len(jsonData))
|
||||
}
|
||||
|
||||
// triggerPromptsCallback 任务成功后的提示词回调
|
||||
// - JSON body 参数:epicycleId(轮次id)/textResult(模型回答消息)
|
||||
func triggerPromptsCallback(ctx context.Context, t *entity.AsynchTask, epicycleId int64) {
|
||||
callbackURL := "prompts-core/session/sessionCallback"
|
||||
headers := forwardHeaders(ctx)
|
||||
var req struct{}
|
||||
payload := map[string]interface{}{
|
||||
"epicycleId": epicycleId,
|
||||
"text": t.TextResult,
|
||||
}
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[提示词回调] JSON序列化失败 epicycleId=%d 错误=%v", epicycleId, err)
|
||||
return
|
||||
}
|
||||
g.Log().Infof(ctx, "[提示词回调] 开始发送 epicycleId=%d 回调地址=%s 请求头数量=%d 消息体大小=%d字节",
|
||||
t.EpicycleId, callbackURL, len(headers), len(jsonData))
|
||||
|
||||
err = http.Post(ctx, callbackURL, headers, &req, jsonData)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[提示词回调] 发送失败 epicycleId=%d 回调地址=%s 错误=%v", t.EpicycleId, callbackURL, err)
|
||||
return
|
||||
}
|
||||
g.Log().Infof(ctx, "[提示词回调] 发送成功 epicycleId=%d 回调地址=%s 消息体大小=%d字节", t.EpicycleId, callbackURL, len(jsonData))
|
||||
}
|
||||
@@ -2,6 +2,8 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"model-gateway/model/dto"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"model-gateway/dao"
|
||||
@@ -14,14 +16,14 @@ var Cleaner = &cleaner{}
|
||||
type cleaner struct{}
|
||||
|
||||
// RunOnce 由上层定时任务触发:执行一次清理/重试
|
||||
func (c *cleaner) RunOnce(ctx context.Context) {
|
||||
func (c *cleaner) RunOnce(ctx context.Context) (res *dto.CleanWorkRes, err error) {
|
||||
// 1) 清理已下载(state=4)且过期的任务(硬删除 + OSS)
|
||||
expired, err := dao.Task.ListExpiredDownloadedGlobal(ctx, 200)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[cleaner] list expired(downloaded) error: %v", err)
|
||||
} else {
|
||||
for _, t := range expired {
|
||||
deleteTmpResult(t.TmpFile)
|
||||
_ = os.Remove(t.TmpFile)
|
||||
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
|
||||
}
|
||||
g.Log().Infof(ctx, "[cleaner] expired(downloaded) cleaned, count=%d", len(expired))
|
||||
@@ -82,11 +84,14 @@ func (c *cleaner) RunOnce(ctx context.Context) {
|
||||
g.Log().Errorf(ctx, "[cleaner] list failed exhausted error: %v", err)
|
||||
} else {
|
||||
for _, t := range exhausted {
|
||||
deleteTmpResult(t.TmpFile)
|
||||
_ = os.Remove(t.TmpFile)
|
||||
// 重试耗尽硬删除:释放闸门占位(兜底,若此前已释放则幂等)
|
||||
ReleaseQueueSlot(ctx, t.ModelName, t.TaskID)
|
||||
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
|
||||
}
|
||||
g.Log().Infof(ctx, "[cleaner] failed exhausted cleaned, count=%d", len(exhausted))
|
||||
}
|
||||
return &dto.CleanWorkRes{
|
||||
Ok: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1,47 +1 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// DetectFileType 根据返回的二进制内容推断 contentType + 扩展名(尽量稳定)
|
||||
func DetectFileType(data []byte) (contentType string, ext string) {
|
||||
if len(data) == 0 {
|
||||
return "application/octet-stream", ""
|
||||
}
|
||||
ct := http.DetectContentType(data)
|
||||
// gateway.DetectContentType 可能带 charset 等参数:text/plain; charset=utf-8
|
||||
if idx := strings.Index(ct, ";"); idx > 0 {
|
||||
ct = strings.TrimSpace(ct[:idx])
|
||||
}
|
||||
switch ct {
|
||||
case "audio/mpeg":
|
||||
return ct, ".mp3"
|
||||
case "audio/wave", "audio/wav", "audio/x-wav":
|
||||
return ct, ".wav"
|
||||
case "video/mp4":
|
||||
return ct, ".mp4"
|
||||
case "image/png":
|
||||
return ct, ".png"
|
||||
case "image/jpeg":
|
||||
return ct, ".jpg"
|
||||
case "application/pdf":
|
||||
return ct, ".pdf"
|
||||
case "text/plain":
|
||||
return ct, ".txt"
|
||||
case "application/json":
|
||||
return ct, ".json"
|
||||
default:
|
||||
// 兜底:尝试从 ct 截取 subtype 作为后缀(例如 application/json)
|
||||
if parts := strings.Split(ct, "/"); len(parts) == 2 {
|
||||
sub := parts[1]
|
||||
// 避免出现 "plain; charset=utf-8" 之类的后缀
|
||||
if idx := strings.Index(sub, ";"); idx > 0 {
|
||||
sub = strings.TrimSpace(sub[:idx])
|
||||
}
|
||||
return ct, "." + sub
|
||||
}
|
||||
return ct, ""
|
||||
}
|
||||
}
|
||||
|
||||
171
service/gateway/gateway_http_service.go
Normal file
171
service/gateway/gateway_http_service.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"mime/multipart"
|
||||
"model-gateway/common/util"
|
||||
"model-gateway/model/entity"
|
||||
"time"
|
||||
|
||||
commonHttp "gitea.com/red-future/common/http"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/guid"
|
||||
)
|
||||
|
||||
type uploadFileResponse struct {
|
||||
FileURL string `json:"fileURL"` // 文件 URL
|
||||
FileSize int `json:"fileSize"` // 文件大小(字节)
|
||||
FileName string `json:"fileName"` // 文件名
|
||||
FileFormat string `json:"fileFormat"` // 文件格式
|
||||
FileAddressPrefix string `json:"fileAddressPrefix"` // 文件地址前缀
|
||||
}
|
||||
|
||||
func UploadByTask(ctx context.Context, _ *entity.AsynchTask, data []byte, fileExt string, _ string) (ossURL string, err error) {
|
||||
// multipart
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
|
||||
ext := fileExt
|
||||
if ext == "" {
|
||||
ext = ".bin"
|
||||
}
|
||||
if ext[0] != '.' {
|
||||
ext = "." + ext
|
||||
}
|
||||
|
||||
filename := fmt.Sprintf("asynch_%d_%s%s", time.Now().Unix(), guid.S(), ext)
|
||||
part, err := writer.CreateFormFile("file", filename)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if _, err = part.Write(data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
headers := util.ForwardHeaders(ctx)
|
||||
fullURL := "oss/file/uploadFile"
|
||||
g.Log().Infof(ctx, "[OSS] upload start url=%s filename=%s size=%d", fullURL, filename, len(data))
|
||||
|
||||
var resp uploadFileResponse
|
||||
if err = commonHttp.Post(ctx, fullURL, headers, &resp, body.Bytes()); err != nil {
|
||||
return "", err
|
||||
}
|
||||
g.Log().Infof(ctx, "[OSS] upload success url=%s size=%d format=%s", resp.FileURL, resp.FileSize, resp.FileFormat)
|
||||
return resp.FileURL, nil
|
||||
}
|
||||
|
||||
// TriggerCallback 任务成功后的回调:
|
||||
// - JSON body 参数:task_id/state/oss_file/file_type/text(可选)
|
||||
func TriggerCallback(ctx context.Context, t *entity.AsynchTask) {
|
||||
headers := util.ForwardHeaders(ctx)
|
||||
var req struct{}
|
||||
payload := map[string]interface{}{
|
||||
"task_id": t.TaskID,
|
||||
"state": t.State,
|
||||
"oss_file": t.OssFile,
|
||||
"file_type": t.FileType,
|
||||
"text": t.TextResult,
|
||||
"error_msg": t.ErrorMsg,
|
||||
}
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[回调] JSON序列化失败 taskId=%s 错误=%v", t.TaskID, err)
|
||||
return
|
||||
}
|
||||
g.Log().Infof(ctx, "[回调] 开始发送 taskId=%s 回调地址=%s 请求头数量=%d 消息体大小=%d字节",
|
||||
t.TaskID, t.CallbackURL, len(headers), len(jsonData))
|
||||
|
||||
err = commonHttp.Post(ctx, t.CallbackURL, headers, &req, jsonData)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[回调] 发送失败 taskId=%s 回调地址=%s 错误=%v", t.TaskID, t.CallbackURL, err)
|
||||
return
|
||||
}
|
||||
g.Log().Infof(ctx, "[回调] 发送成功 taskId=%s 回调地址=%s 消息体大小=%d字节", t.TaskID, t.CallbackURL, len(jsonData))
|
||||
}
|
||||
|
||||
// TriggerPromptsCallback 任务成功后的提示词回调
|
||||
// - JSON body 参数:epicycleId(轮次id)/textResult(模型回答消息)
|
||||
func TriggerPromptsCallback(ctx context.Context, t *entity.AsynchTask, epicycleId int64) {
|
||||
callbackURL := "prompts-core/session/sessionCallback"
|
||||
headers := util.ForwardHeaders(ctx)
|
||||
var req struct{}
|
||||
payload := map[string]interface{}{
|
||||
"epicycleId": epicycleId,
|
||||
"text": t.TextResult,
|
||||
}
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[提示词回调] JSON序列化失败 epicycleId=%d 错误=%v", epicycleId, err)
|
||||
return
|
||||
}
|
||||
g.Log().Infof(ctx, "[提示词回调] 开始发送 epicycleId=%d 回调地址=%s 请求头数量=%d 消息体大小=%d字节",
|
||||
t.EpicycleId, callbackURL, len(headers), len(jsonData))
|
||||
|
||||
err = commonHttp.Post(ctx, callbackURL, headers, &req, jsonData)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[提示词回调] 发送失败 epicycleId=%d 回调地址=%s 错误=%v", t.EpicycleId, callbackURL, err)
|
||||
return
|
||||
}
|
||||
g.Log().Infof(ctx, "[提示词回调] 发送成功 epicycleId=%d 回调地址=%s 消息体大小=%d字节", t.EpicycleId, callbackURL, len(jsonData))
|
||||
}
|
||||
|
||||
// IsSuperAdmin 调用admin-go服务检查是否是超级管理员
|
||||
func IsSuperAdmin(ctx context.Context) (res bool, err error) {
|
||||
headers := util.ForwardHeaders(ctx)
|
||||
var r = make(map[string]bool)
|
||||
if err = commonHttp.Get(ctx, "admin-go/api/v1/system/user/checkIsSuperAdmin", headers, &r); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return r["isSuperAdmin"], err
|
||||
}
|
||||
|
||||
//// callback 向回调地址 POST 任务结果(与查询接口 GetTaskRes 出参一致)
|
||||
//func (s *audioTaskService) callback(ctx context.Context, taskID, status, errMsg, callbackURL string) {
|
||||
// if callbackURL == "" {
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// task, _ := dao.TranscribeTask.GetByTaskID(ctx, taskID)
|
||||
// if task == nil {
|
||||
// g.Log().Errorf(ctx, "[回调 %s] 任务不存在", taskID)
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// detailList, _ := dao.TranscribeTaskDetail.ListByTaskID(ctx, taskID)
|
||||
// detailItems := make([]dto.TranscribeTaskDetailItem, 0, len(detailList))
|
||||
// for i := range detailList {
|
||||
// detailItems = append(detailItems, dao.DetailEntityToItem(&detailList[i]))
|
||||
// }
|
||||
//
|
||||
// // 构建与查询接口一致的 taskInfo
|
||||
// taskInfo := dao.EntityToItem(task)
|
||||
//
|
||||
// // 兼容历史数据: 从 result 中补全 scenes 等字段
|
||||
// detailItems = enrichDetailsFromResult(task.Result, detailItems)
|
||||
//
|
||||
// payload := dto.CallbackPayload{
|
||||
// TaskInfo: taskInfo,
|
||||
// DetailList: detailItems,
|
||||
// }
|
||||
//
|
||||
// body, _ := json.Marshal(payload)
|
||||
//
|
||||
// // 透传调用方的用户信息
|
||||
// userJSON, _ := json.Marshal(beans.User{UserName: "admin", TenantId: 1})
|
||||
//
|
||||
// req, _ := http.NewRequest("POST", callbackURL, bytes.NewReader(body))
|
||||
// req.Header.Set("Content-Type", "application/json")
|
||||
// req.Header.Set("X-User-Info", string(userJSON))
|
||||
//
|
||||
// resp, reqErr := http.DefaultClient.Do(req)
|
||||
// if reqErr != nil {
|
||||
// g.Log().Errorf(ctx, "[回调 %s] 请求失败: %v", taskID, reqErr)
|
||||
// return
|
||||
// }
|
||||
// defer resp.Body.Close()
|
||||
//
|
||||
// respBody, _ := io.ReadAll(resp.Body)
|
||||
// g.Log().Infof(ctx, "[回调 %s] 响应 status=%d, body=%s", taskID, resp.StatusCode, string(respBody))
|
||||
//}
|
||||
@@ -1,53 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
// asyncCtx 固化异步执行所需的 token/user,避免请求结束后丢失(仅在“同请求内起 goroutine”有用)。
|
||||
// 本项目当前是“落库 + 后台 worker”模式,因此还会把必要信息持久化到任务表的 request_payload 中。
|
||||
func asyncCtx(ctx context.Context) context.Context {
|
||||
asyncCtx := context.WithoutCancel(ctx)
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
if token := r.Header.Get("Authorization"); token != "" {
|
||||
asyncCtx = context.WithValue(asyncCtx, "token", token)
|
||||
}
|
||||
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
|
||||
asyncCtx = context.WithValue(asyncCtx, "xUserInfo", userInfo)
|
||||
}
|
||||
}
|
||||
if user, err := utils.GetUserInfo(ctx); err == nil && user != nil {
|
||||
asyncCtx = context.WithValue(asyncCtx, "user", user)
|
||||
}
|
||||
return asyncCtx
|
||||
}
|
||||
|
||||
// forwardHeaders 透传调用链路中必须的头信息(优先使用 ctx 里固化的 token / xUserInfo)。
|
||||
func forwardHeaders(ctx context.Context) map[string]string {
|
||||
headers := make(map[string]string)
|
||||
|
||||
if token, ok := ctx.Value("token").(string); ok && token != "" {
|
||||
headers["Authorization"] = token
|
||||
}
|
||||
if x, ok := ctx.Value("xUserInfo").(string); ok && x != "" {
|
||||
headers["X-User-Info"] = x
|
||||
}
|
||||
|
||||
// 兜底:从请求头拿
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
if headers["Authorization"] == "" {
|
||||
if token := r.Header.Get("Authorization"); token != "" {
|
||||
headers["Authorization"] = token
|
||||
}
|
||||
}
|
||||
if headers["X-User-Info"] == "" {
|
||||
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
|
||||
headers["X-User-Info"] = userInfo
|
||||
}
|
||||
}
|
||||
}
|
||||
return headers
|
||||
}
|
||||
@@ -3,13 +3,15 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"model-gateway/common/util"
|
||||
"model-gateway/consts/public"
|
||||
"model-gateway/dao"
|
||||
"model-gateway/model/dto"
|
||||
"model-gateway/model/entity"
|
||||
"model-gateway/service/gateway"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"gitea.com/red-future/common/http"
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
@@ -20,28 +22,20 @@ var Model = &modelService{}
|
||||
|
||||
type modelService struct{}
|
||||
|
||||
// IsSuperAdmin 调用admin-go服务检查是否是超级管理员
|
||||
func (s *modelService) IsSuperAdmin(ctx context.Context) (res bool, err error) {
|
||||
headers := forwardHeaders(ctx)
|
||||
var r = make(map[string]bool)
|
||||
if err = http.Get(ctx, "admin-go/api/v1/system/user/checkIsSuperAdmin", headers, &r); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return r["isSuperAdmin"], err
|
||||
}
|
||||
|
||||
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
|
||||
// 获取当前会话模型
|
||||
if !g.IsEmpty(req.IsChatModel) && *req.IsChatModel == 1 {
|
||||
var model *entity.AsynchModel
|
||||
model, err = dao.Model.GetByIsChatModel(ctx)
|
||||
model, err = dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
IsChatModel: new(1),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 如果有会话模型,那就改变为 0
|
||||
if model != nil {
|
||||
_, err = dao.Model.Update(ctx, &dto.UpdateModelReq{
|
||||
ID: model.Id,
|
||||
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: model.Id},
|
||||
IsChatModel: gconv.PtrInt(0),
|
||||
})
|
||||
if err != nil {
|
||||
@@ -51,14 +45,40 @@ func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res
|
||||
}
|
||||
|
||||
req.IsOwner = gconv.PtrInt(1)
|
||||
admin, err := s.IsSuperAdmin(ctx)
|
||||
admin, err := gateway.IsSuperAdmin(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if admin {
|
||||
req.IsOwner = gconv.PtrInt(0)
|
||||
}
|
||||
id, err := dao.Model.Insert(ctx, req)
|
||||
id, err := dao.Model.Insert(ctx, &entity.AsynchModel{
|
||||
ModelName: req.ModelName,
|
||||
ModelType: req.ModelType,
|
||||
BaseURL: req.BaseURL,
|
||||
HttpMethod: req.HttpMethod,
|
||||
HeadMsg: req.HeadMsg,
|
||||
Form: req.Form,
|
||||
RequestMapping: req.RequestMapping,
|
||||
ResponseMapping: req.ResponseMapping,
|
||||
ResponseBody: req.ResponseBody,
|
||||
ResponseTokenField: req.ResponseTokenField,
|
||||
IsPrivate: req.IsPrivate,
|
||||
IsChatModel: req.IsChatModel,
|
||||
ApiKey: req.ApiKey,
|
||||
Enabled: req.Enabled,
|
||||
MaxConcurrency: req.MaxConcurrency,
|
||||
QueueLimit: req.QueueLimit,
|
||||
TimeoutSeconds: req.TimeoutSeconds,
|
||||
ExpectedSeconds: req.ExpectedSeconds,
|
||||
RetryTimes: req.RetryTimes,
|
||||
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
|
||||
AutoCleanSeconds: req.AutoCleanSeconds,
|
||||
Remark: req.Remark,
|
||||
IsOwner: req.IsOwner,
|
||||
OperatorName: req.OperatorName,
|
||||
TokenConfig: req.TokenConfig,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -69,7 +89,9 @@ func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) erro
|
||||
//根据当前 isChatModel 来判断是否更新模型
|
||||
if req.IsChatModel == gconv.PtrInt(1) {
|
||||
//判断当前用户是否有会话模型
|
||||
model, err := dao.Model.GetByIsChatModel(ctx)
|
||||
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
IsChatModel: new(1),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -79,68 +101,146 @@ func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) erro
|
||||
}
|
||||
|
||||
req.IsOwner = gconv.PtrInt(1)
|
||||
admin, err := s.IsSuperAdmin(ctx)
|
||||
admin, err := gateway.IsSuperAdmin(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if admin {
|
||||
req.IsOwner = gconv.PtrInt(0)
|
||||
_, err = dao.Model.Update(ctx, req)
|
||||
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
|
||||
ModelName: req.ModelName,
|
||||
ModelType: req.ModelType,
|
||||
BaseURL: req.BaseURL,
|
||||
HttpMethod: req.HttpMethod,
|
||||
HeadMsg: req.HeadMsg,
|
||||
Form: req.Form,
|
||||
RequestMapping: req.RequestMapping,
|
||||
ResponseMapping: req.ResponseMapping,
|
||||
ResponseBody: req.ResponseBody,
|
||||
ResponseTokenField: req.ResponseTokenField,
|
||||
IsPrivate: req.IsPrivate,
|
||||
IsChatModel: req.IsChatModel,
|
||||
ApiKey: req.ApiKey,
|
||||
Enabled: req.Enabled,
|
||||
MaxConcurrency: req.MaxConcurrency,
|
||||
QueueLimit: req.QueueLimit,
|
||||
TimeoutSeconds: req.TimeoutSeconds,
|
||||
ExpectedSeconds: req.ExpectedSeconds,
|
||||
RetryTimes: req.RetryTimes,
|
||||
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
|
||||
AutoCleanSeconds: req.AutoCleanSeconds,
|
||||
Remark: req.Remark,
|
||||
IsOwner: req.IsOwner,
|
||||
OperatorName: req.OperatorName,
|
||||
TokenConfig: req.TokenConfig,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var user *beans.User
|
||||
user, err = utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 判断当前传过来的模型id的模型是否是超级管理员的。如果是超管的进行创建,否则更新
|
||||
var count int
|
||||
count, err = dao.Model.Count(ctx, &dto.GetModelReq{
|
||||
ID: req.ID,
|
||||
Creator: user.UserName,
|
||||
model, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count == 0 {
|
||||
if model.TenantId == 1 {
|
||||
insertDto := new(dto.CreateModelReq)
|
||||
err = gconv.Struct(req, insertDto)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = dao.Model.Insert(ctx, insertDto)
|
||||
_, err = dao.Model.Insert(ctx, &entity.AsynchModel{
|
||||
ModelName: req.ModelName,
|
||||
ModelType: req.ModelType,
|
||||
BaseURL: req.BaseURL,
|
||||
HttpMethod: req.HttpMethod,
|
||||
HeadMsg: req.HeadMsg,
|
||||
Form: req.Form,
|
||||
RequestMapping: req.RequestMapping,
|
||||
ResponseMapping: req.ResponseMapping,
|
||||
ResponseBody: req.ResponseBody,
|
||||
ResponseTokenField: req.ResponseTokenField,
|
||||
IsPrivate: req.IsPrivate,
|
||||
IsChatModel: req.IsChatModel,
|
||||
ApiKey: req.ApiKey,
|
||||
Enabled: req.Enabled,
|
||||
MaxConcurrency: req.MaxConcurrency,
|
||||
QueueLimit: req.QueueLimit,
|
||||
TimeoutSeconds: req.TimeoutSeconds,
|
||||
ExpectedSeconds: req.ExpectedSeconds,
|
||||
RetryTimes: req.RetryTimes,
|
||||
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
|
||||
AutoCleanSeconds: req.AutoCleanSeconds,
|
||||
Remark: req.Remark,
|
||||
IsOwner: req.IsOwner,
|
||||
OperatorName: req.OperatorName,
|
||||
TokenConfig: req.TokenConfig,
|
||||
})
|
||||
return err
|
||||
}
|
||||
_, err = dao.Model.Update(ctx, req)
|
||||
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
|
||||
ModelName: req.ModelName,
|
||||
ModelType: req.ModelType,
|
||||
BaseURL: req.BaseURL,
|
||||
HttpMethod: req.HttpMethod,
|
||||
HeadMsg: req.HeadMsg,
|
||||
Form: req.Form,
|
||||
RequestMapping: req.RequestMapping,
|
||||
ResponseMapping: req.ResponseMapping,
|
||||
ResponseBody: req.ResponseBody,
|
||||
ResponseTokenField: req.ResponseTokenField,
|
||||
IsPrivate: req.IsPrivate,
|
||||
IsChatModel: req.IsChatModel,
|
||||
ApiKey: req.ApiKey,
|
||||
Enabled: req.Enabled,
|
||||
MaxConcurrency: req.MaxConcurrency,
|
||||
QueueLimit: req.QueueLimit,
|
||||
TimeoutSeconds: req.TimeoutSeconds,
|
||||
ExpectedSeconds: req.ExpectedSeconds,
|
||||
RetryTimes: req.RetryTimes,
|
||||
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
|
||||
AutoCleanSeconds: req.AutoCleanSeconds,
|
||||
Remark: req.Remark,
|
||||
IsOwner: req.IsOwner,
|
||||
OperatorName: req.OperatorName,
|
||||
TokenConfig: req.TokenConfig,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *modelService) Delete(ctx context.Context, id string) error {
|
||||
_, err := dao.Model.DeleteByID(ctx, id)
|
||||
func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) error {
|
||||
_, err := dao.Model.Delete(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *modelService) Get(ctx context.Context, id int64) (*entity.AsynchModel, error) {
|
||||
model, err := dao.Model.Get(ctx, id)
|
||||
func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetModelRes, error) {
|
||||
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
model.Form = ParseJSONField(model.Form)
|
||||
model.RequestMapping = ParseJSONField(model.RequestMapping)
|
||||
model.ResponseMapping = ParseJSONField(model.ResponseMapping)
|
||||
model.ResponseBody = ParseJSONField(model.ResponseBody)
|
||||
return model, nil
|
||||
model.Form = util.ParseJSONField(model.Form)
|
||||
model.RequestMapping = util.ParseJSONField(model.RequestMapping)
|
||||
model.ResponseMapping = util.ParseJSONField(model.ResponseMapping)
|
||||
model.ResponseBody = util.ParseJSONField(model.ResponseBody)
|
||||
return &dto.GetModelRes{
|
||||
Model: model,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (list []*entity.AsynchModel, total int, err error) {
|
||||
func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) {
|
||||
var models []*entity.AsynchModel
|
||||
|
||||
req.IsOwner = gconv.PtrInt(1)
|
||||
admin, err := s.IsSuperAdmin(ctx)
|
||||
admin, err := gateway.IsSuperAdmin(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -151,63 +251,55 @@ func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (list []
|
||||
var user *beans.User
|
||||
user, err = utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
return nil, err
|
||||
}
|
||||
req.Creator = user.UserName
|
||||
|
||||
models, total, err = dao.Model.GetByCreatorAndPlatform(ctx, req)
|
||||
models, total, err := dao.Model.GetByCreatorAndPlatform(ctx, req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 处理列表中每条记录的 JSONB 字段
|
||||
for _, m := range models {
|
||||
m.Form = ParseJSONField(m.Form)
|
||||
m.RequestMapping = ParseJSONField(m.RequestMapping)
|
||||
m.ResponseMapping = ParseJSONField(m.ResponseMapping)
|
||||
m.ResponseBody = ParseJSONField(m.ResponseBody)
|
||||
m.Form = util.ParseJSONField(m.Form)
|
||||
m.RequestMapping = util.ParseJSONField(m.RequestMapping)
|
||||
m.ResponseMapping = util.ParseJSONField(m.ResponseMapping)
|
||||
m.ResponseBody = util.ParseJSONField(m.ResponseBody)
|
||||
}
|
||||
return models, total, nil
|
||||
return &dto.ListModelRes{
|
||||
List: models,
|
||||
Total: total,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetModelTypesFromConfig 从配置文件读取模型类型
|
||||
func GetModelTypesFromConfig(ctx context.Context) map[int]string {
|
||||
typeMap := make(map[int]string)
|
||||
|
||||
// 读取配置
|
||||
configMap := g.Cfg().MustGet(ctx, "modelType.types").Map()
|
||||
for k, v := range configMap {
|
||||
typeID := gconv.Int(k)
|
||||
typeName := gconv.String(v)
|
||||
if typeID > 0 && typeName != "" {
|
||||
typeMap[typeID] = typeName
|
||||
}
|
||||
func GetModelTypesFromConfig() (res *dto.TypeItem, err error) {
|
||||
// 返回副本,避免外部修改
|
||||
types := make(map[int]string, len(public.ModelTypeName))
|
||||
for k, v := range public.ModelTypeName {
|
||||
types[k] = v
|
||||
}
|
||||
// 如果配置为空,使用默认值
|
||||
if len(typeMap) == 0 {
|
||||
typeMap = map[int]string{
|
||||
1: "推理模型",
|
||||
2: "图片模型",
|
||||
3: "音频模型",
|
||||
4: "向量化模型",
|
||||
5: "全模态模型",
|
||||
}
|
||||
}
|
||||
return typeMap
|
||||
return &dto.TypeItem{
|
||||
Type: types,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error {
|
||||
// 校验新会话模型是否存在
|
||||
newModel, err := dao.Model.Get(ctx, req.Id)
|
||||
newModel, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if newModel == nil {
|
||||
return errors.New("新会话模型不存在")
|
||||
}
|
||||
|
||||
// 获取当前用户会话模型
|
||||
currentModel, err := dao.Model.GetByIsChatModel(ctx)
|
||||
currentModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
IsChatModel: new(1),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -219,8 +311,8 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM
|
||||
|
||||
// 如果点击的就是当前会话模型(已经是1),取消它(设为0)
|
||||
if currentModel.Id != req.Id {
|
||||
_, err = dao.Model.Update(ctx, &dto.UpdateModelReq{
|
||||
ID: currentModel.Id,
|
||||
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id},
|
||||
IsChatModel: gconv.PtrInt(0),
|
||||
})
|
||||
if err != nil {
|
||||
@@ -230,8 +322,8 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM
|
||||
}
|
||||
|
||||
// 设置当前为会话模型(设为1)
|
||||
_, err = dao.Model.Update(ctx, &dto.UpdateModelReq{
|
||||
ID: req.Id,
|
||||
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id},
|
||||
IsChatModel: gconv.PtrInt(1),
|
||||
})
|
||||
return err
|
||||
@@ -239,17 +331,21 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *modelService) GetIsChatModel(ctx context.Context) (*entity.AsynchModel, error) {
|
||||
model, err := dao.Model.GetByIsChatModel(ctx)
|
||||
func (s *modelService) GetIsChatModel(ctx context.Context) (*dto.GetIsChatModelRes, error) {
|
||||
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
IsChatModel: new(1),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if model == nil {
|
||||
return nil, nil
|
||||
}
|
||||
model.Form = ParseJSONField(model.Form)
|
||||
model.RequestMapping = ParseJSONField(model.RequestMapping)
|
||||
model.ResponseMapping = ParseJSONField(model.ResponseMapping)
|
||||
model.ResponseBody = ParseJSONField(model.ResponseBody)
|
||||
return model, nil
|
||||
model.Form = util.ParseJSONField(model.Form)
|
||||
model.RequestMapping = util.ParseJSONField(model.RequestMapping)
|
||||
model.ResponseMapping = util.ParseJSONField(model.ResponseMapping)
|
||||
model.ResponseBody = util.ParseJSONField(model.ResponseBody)
|
||||
return &dto.GetIsChatModelRes{
|
||||
Model: model,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
package service
|
||||
|
||||
import "github.com/gogf/gf/v2/util/gconv"
|
||||
|
||||
// parseStoredPayload 解析入库的 request_payload,拆出模型调用 payload 与透传 headers
|
||||
// 入库格式:{"payload": <any>, "headers": {"Authorization": "...", "X-User-Info":"..."}}
|
||||
func parseStoredPayload(v any) (payload any, headers map[string]string) {
|
||||
if v == nil {
|
||||
return nil, nil
|
||||
}
|
||||
m := gconv.Map(v)
|
||||
if len(m) == 0 {
|
||||
return v, nil
|
||||
}
|
||||
if h, ok := m["headers"]; ok {
|
||||
headers = gconv.MapStrStr(h)
|
||||
}
|
||||
if p, ok := m["payload"]; ok {
|
||||
payload = p
|
||||
} else {
|
||||
payload = v
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"model-gateway/model/entity"
|
||||
)
|
||||
|
||||
// StorageService 结果存储(OSS/MinIO)抽象
|
||||
type StorageService interface {
|
||||
UploadByTask(ctx context.Context, t *entity.AsynchTask, data []byte, fileExt string, contentType string) (ossURL string, err error)
|
||||
}
|
||||
|
||||
// Storage 默认存储实现(优先对接你们的 oss 文件服务;必要时也可以切到 MinIO)
|
||||
var Storage StorageService = &ossStorage{}
|
||||
|
||||
var ErrStorageNotConfigured = errors.New("存储未配置")
|
||||
@@ -1,81 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"mime/multipart"
|
||||
"time"
|
||||
|
||||
"model-gateway/model/entity"
|
||||
|
||||
commonHttp "gitea.com/red-future/common/http"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
"github.com/gogf/gf/v2/util/guid"
|
||||
)
|
||||
|
||||
// 对接你们的 oss 文件服务:POST oss/file/uploadFile (multipart/form-data)
|
||||
type ossStorage struct{}
|
||||
|
||||
type uploadFileResponse struct {
|
||||
FileURL string `json:"fileURL"` // 文件 URL
|
||||
FileSize int `json:"fileSize"` // 文件大小(字节)
|
||||
FileName string `json:"fileName"` // 文件名
|
||||
FileFormat string `json:"fileFormat"` // 文件格式
|
||||
FileAddressPrefix string `json:"fileAddressPrefix"` // 文件地址前缀
|
||||
}
|
||||
|
||||
func (s *ossStorage) UploadByTask(ctx context.Context, _ *entity.AsynchTask, data []byte, fileExt string, _ string) (ossURL string, err error) {
|
||||
// multipart
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
|
||||
ext := fileExt
|
||||
if ext == "" {
|
||||
ext = ".bin"
|
||||
}
|
||||
if ext[0] != '.' {
|
||||
ext = "." + ext
|
||||
}
|
||||
|
||||
filename := fmt.Sprintf("asynch_%d_%s%s", time.Now().Unix(), guid.S(), ext)
|
||||
part, err := writer.CreateFormFile("file", filename)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if _, err := part.Write(data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
contentType := writer.FormDataContentType()
|
||||
if err := writer.Close(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
headers := forwardHeaders(ctx)
|
||||
headers["Content-Type"] = contentType
|
||||
|
||||
fullURL := "oss/file/uploadFile"
|
||||
g.Log().Infof(ctx, "[OSS] upload start url=%s filename=%s size=%d", fullURL, filename, len(data))
|
||||
|
||||
var resp uploadFileResponse
|
||||
if err := commonHttp.Post(ctx, fullURL, headers, &resp, body.Bytes()); err != nil {
|
||||
return "", err
|
||||
}
|
||||
g.Log().Infof(ctx, "[OSS] upload success url=%s size=%d format=%s", resp.FileURL, resp.FileSize, resp.FileFormat)
|
||||
return resp.FileURL, nil
|
||||
}
|
||||
|
||||
// setTaskHeadersToCtx 把任务入库时保存的 header 信息注入 ctx,给 worker 调 OSS 用
|
||||
func setTaskHeadersToCtx(ctx context.Context, headers map[string]string) context.Context {
|
||||
if headers == nil {
|
||||
return ctx
|
||||
}
|
||||
if v := gconv.String(headers["Authorization"]); v != "" {
|
||||
ctx = context.WithValue(ctx, "token", v)
|
||||
}
|
||||
if v := gconv.String(headers["X-User-Info"]); v != "" {
|
||||
ctx = context.WithValue(ctx, "xUserInfo", v)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
@@ -3,7 +3,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"model-gateway/common/util"
|
||||
"time"
|
||||
|
||||
"model-gateway/dao"
|
||||
@@ -21,13 +21,13 @@ var Task = &taskService{}
|
||||
type taskService struct{}
|
||||
|
||||
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
|
||||
fmt.Printf("打印请求:%+v", req)
|
||||
startAt := time.Now()
|
||||
// 固化 token/user 等信息
|
||||
ctx = asyncCtx(ctx)
|
||||
|
||||
ctx = util.AsyncCtx(ctx)
|
||||
// 1) 检查模型配置
|
||||
m, err := dao.Model.GetByModelName(ctx, req.ModelName)
|
||||
m, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
ModelName: req.ModelName,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -51,7 +51,7 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
|
||||
// 将调用模型的 payload 与透传头信息一起存入 request_payload,供后台 worker 使用
|
||||
storedPayload := map[string]any{
|
||||
"payload": req.RequestPayload,
|
||||
"headers": forwardHeaders(ctx),
|
||||
"headers": util.ForwardHeaders(ctx),
|
||||
}
|
||||
|
||||
t := &entity.AsynchTask{
|
||||
@@ -127,7 +127,9 @@ func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string,
|
||||
defer ticker.Stop()
|
||||
|
||||
tryRun := func() bool {
|
||||
t, err := dao.Task.GetByTaskID(ctx, taskID)
|
||||
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
|
||||
TaskID: taskID,
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[task-auto-run][stop] taskId=%s reason=query_failed err=%v", taskID, err)
|
||||
return true
|
||||
@@ -138,7 +140,7 @@ func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string,
|
||||
}
|
||||
switch t.State {
|
||||
case 0:
|
||||
if err := AsyncWorker.RunByTaskID(ctx, taskID, epicycleId); err != nil {
|
||||
if err = AsyncWorker.RunByTaskID(ctx, taskID, epicycleId); err != nil {
|
||||
g.Log().Warningf(ctx, "[task-auto-run][retry] taskId=%s state=0 err=%v", taskID, err)
|
||||
} else {
|
||||
g.Log().Infof(ctx, "[task-auto-run][triggered] taskId=%s state=0", taskID)
|
||||
@@ -175,7 +177,9 @@ func (s *taskService) pollAndRunUntilPicked(ctx context.Context, taskID string,
|
||||
}
|
||||
|
||||
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
|
||||
t, err := dao.Task.GetByTaskID(ctx, taskID)
|
||||
t, err := dao.Task.Get(ctx, &entity.AsynchTask{
|
||||
TaskID: taskID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -209,7 +213,9 @@ func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (r
|
||||
continue
|
||||
}
|
||||
// 按模型配置决定保留时间
|
||||
m, err := dao.Model.GetByModelName(ctx, t.ModelName)
|
||||
m, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
ModelName: t.ModelName,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// saveTmpResult 将模型输出写入临时文件,用于 OSS 上传失败后的“仅重试 OSS”。
|
||||
func saveTmpResult(taskID string, data []byte, ext string) (string, error) {
|
||||
dir := filepath.Join(os.TempDir(), "model-asynch")
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if ext == "" {
|
||||
ext = ".bin"
|
||||
}
|
||||
if ext[0] != '.' {
|
||||
ext = "." + ext
|
||||
}
|
||||
path := filepath.Join(dir, fmt.Sprintf("%s%s", taskID, ext))
|
||||
if err := os.WriteFile(path, data, 0o644); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
|
||||
func loadTmpResult(path string) ([]byte, error) {
|
||||
return os.ReadFile(path)
|
||||
}
|
||||
|
||||
func deleteTmpResult(path string) {
|
||||
if path == "" {
|
||||
return
|
||||
}
|
||||
_ = os.Remove(path)
|
||||
}
|
||||
|
||||
113
service/utils.go
113
service/utils.go
@@ -1,113 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/gogf/gf/v2/container/gvar"
|
||||
)
|
||||
|
||||
func normalizeFormValue(v any) any {
|
||||
// 目标:对外永远返回 JSON 数组/对象,而不是字符串。
|
||||
if v == nil {
|
||||
return []any{}
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case string:
|
||||
s := strings.TrimSpace(t)
|
||||
if s == "" {
|
||||
return []any{}
|
||||
}
|
||||
return normalizeFormValueFromJSONString(s)
|
||||
case []byte:
|
||||
if len(t) == 0 {
|
||||
return []any{}
|
||||
}
|
||||
return normalizeFormValueFromJSONBytes(t)
|
||||
case *gvar.Var:
|
||||
// goframe 常见的 DB 返回类型
|
||||
if t == nil {
|
||||
return []any{}
|
||||
}
|
||||
b := t.Bytes()
|
||||
if len(b) > 0 {
|
||||
return normalizeFormValueFromJSONBytes(b)
|
||||
}
|
||||
s := strings.TrimSpace(t.String())
|
||||
if s == "" {
|
||||
return []any{}
|
||||
}
|
||||
return normalizeFormValueFromJSONString(s)
|
||||
default:
|
||||
// 尝试兼容其他“像 JSON 的值类型”(例如实现了 Bytes/String 的包装类型)
|
||||
if vb, ok := v.(interface{ Bytes() []byte }); ok {
|
||||
if b := vb.Bytes(); len(b) > 0 {
|
||||
return normalizeFormValueFromJSONBytes(b)
|
||||
}
|
||||
}
|
||||
if vs, ok := v.(interface{ String() string }); ok {
|
||||
if s := strings.TrimSpace(vs.String()); s != "" {
|
||||
return normalizeFormValueFromJSONString(s)
|
||||
}
|
||||
}
|
||||
// 已经是 []any / map[string]any 等结构
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
// 兼容“JSONB 里存了 JSON 字符串”的历史数据:
|
||||
// 例如 form_json = '"[]"' 或 '"[{...}]"'(外层是字符串,内层才是数组/对象)
|
||||
func normalizeFormValueFromJSONString(s string) any {
|
||||
var out any
|
||||
if err := json.Unmarshal([]byte(s), &out); err != nil || out == nil {
|
||||
return []any{}
|
||||
}
|
||||
// 如果解出来还是 string,且看起来是 JSON,再解一层
|
||||
if inner, ok := out.(string); ok {
|
||||
inner = strings.TrimSpace(inner)
|
||||
if inner == "" {
|
||||
return []any{}
|
||||
}
|
||||
if strings.HasPrefix(inner, "[") || strings.HasPrefix(inner, "{") {
|
||||
var out2 any
|
||||
if err := json.Unmarshal([]byte(inner), &out2); err == nil && out2 != nil {
|
||||
return out2
|
||||
}
|
||||
}
|
||||
return []any{}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeFormValueFromJSONBytes(b []byte) any {
|
||||
var out any
|
||||
if err := json.Unmarshal(b, &out); err != nil || out == nil {
|
||||
return []any{}
|
||||
}
|
||||
// bytes 解出来也可能是 string(同上)
|
||||
if inner, ok := out.(string); ok {
|
||||
return normalizeFormValueFromJSONString(inner)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func ParseJSONField(field any) any {
|
||||
var v *gvar.Var
|
||||
switch val := field.(type) {
|
||||
case *gvar.Var:
|
||||
v = val
|
||||
default:
|
||||
return field
|
||||
}
|
||||
|
||||
if v == nil || v.IsNil() || v.IsEmpty() {
|
||||
return nil
|
||||
}
|
||||
|
||||
str := v.String()
|
||||
var result any
|
||||
if json.Unmarshal([]byte(str), &result) == nil {
|
||||
return result
|
||||
}
|
||||
return str
|
||||
}
|
||||
@@ -2,7 +2,13 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"model-gateway/common/util"
|
||||
"model-gateway/model/dto"
|
||||
"model-gateway/service/gateway"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
@@ -23,24 +29,23 @@ type asyncWorker struct {
|
||||
// RunOnce 由上层定时任务触发:一次性抢占并处理一批任务
|
||||
// - batchSize: 本次抢占数量
|
||||
// - goroutines: 本次并发数(协程池大小)
|
||||
func (w *asyncWorker) RunOnce(ctx context.Context, batchSize, goroutines int) (claimed int, err error) {
|
||||
if batchSize <= 0 {
|
||||
batchSize = 10
|
||||
func (w *asyncWorker) RunOnce(ctx context.Context, req *dto.RunWorkReq) (res *dto.RunWorkRes, err error) {
|
||||
if req.BatchSize <= 0 {
|
||||
req.BatchSize = 10
|
||||
}
|
||||
if goroutines <= 0 {
|
||||
goroutines = 1
|
||||
if req.Goroutines <= 0 {
|
||||
req.Goroutines = 1
|
||||
}
|
||||
tasks, err := dao.Task.ClaimPendingGlobal(ctx, batchSize)
|
||||
tasks, err := dao.Task.ClaimPendingGlobal(ctx, req.BatchSize)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return nil, err
|
||||
}
|
||||
if len(tasks) == 0 {
|
||||
return 0, nil
|
||||
return nil, errors.New("no task to run")
|
||||
}
|
||||
pool := grpool.New(goroutines)
|
||||
pool := grpool.New(req.Goroutines)
|
||||
defer pool.Close()
|
||||
|
||||
claimed = len(tasks)
|
||||
claimed := len(tasks)
|
||||
done := make(chan struct{}, claimed)
|
||||
for _, t := range tasks {
|
||||
task := t
|
||||
@@ -58,7 +63,9 @@ func (w *asyncWorker) RunOnce(ctx context.Context, batchSize, goroutines int) (c
|
||||
for i := 0; i < claimed; i++ {
|
||||
<-done
|
||||
}
|
||||
return claimed, nil
|
||||
return &dto.RunWorkRes{
|
||||
Claimed: claimed,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RunByTaskID 创建任务后立即异步尝试执行当前任务:
|
||||
@@ -78,9 +85,9 @@ func (w *asyncWorker) RunByTaskID(ctx context.Context, taskID string, epicycleId
|
||||
|
||||
func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicycleId int64) {
|
||||
// 从任务入库的 request_payload 里恢复 payload + headers
|
||||
payload, headers := parseStoredPayload(t.RequestPayload)
|
||||
payload, headers := util.ParseStoredPayload(t.RequestPayload)
|
||||
if len(headers) > 0 {
|
||||
ctx = setTaskHeadersToCtx(ctx, headers)
|
||||
ctx = util.SetTaskHeadersToCtx(ctx, headers)
|
||||
}
|
||||
|
||||
// 1) 拉取模型配置
|
||||
@@ -91,7 +98,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
||||
// ============ 失败回调 ============
|
||||
t.State = 3
|
||||
t.ErrorMsg = err.Error()
|
||||
go triggerCallback(context.WithoutCancel(ctx), t)
|
||||
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||||
// ================================
|
||||
return
|
||||
}
|
||||
@@ -102,7 +109,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
||||
// ============ 失败回调 ============
|
||||
t.State = 3
|
||||
t.ErrorMsg = errMsg
|
||||
go triggerCallback(context.WithoutCancel(ctx), t)
|
||||
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||||
// ================================
|
||||
return
|
||||
}
|
||||
@@ -118,7 +125,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
||||
// ============ 失败回调 ============
|
||||
t.State = 3
|
||||
t.ErrorMsg = err.Error()
|
||||
go triggerCallback(context.WithoutCancel(ctx), t)
|
||||
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||||
// ================================
|
||||
return
|
||||
}
|
||||
@@ -147,9 +154,9 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
||||
|
||||
// phase=1 表示模型已成功但 OSS 上传失败:优先从临时文件加载
|
||||
if t.Phase == 1 && strings.TrimSpace(t.TmpFile) != "" {
|
||||
data, err = loadTmpResult(t.TmpFile)
|
||||
data, err = os.ReadFile(t.TmpFile)
|
||||
if err == nil && len(data) > 0 {
|
||||
contentType, ext = DetectFileType(data)
|
||||
contentType, ext = util.DetectFileType(data)
|
||||
} else {
|
||||
data = nil
|
||||
}
|
||||
@@ -165,11 +172,11 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
||||
// ============ 失败回调 ============
|
||||
t.State = 3
|
||||
t.ErrorMsg = err.Error()
|
||||
go triggerCallback(context.WithoutCancel(ctx), t)
|
||||
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||||
// ================================
|
||||
return
|
||||
}
|
||||
contentType, ext = DetectFileType(data)
|
||||
contentType, ext = util.DetectFileType(data)
|
||||
if utf8.Valid(data) && (strings.HasPrefix(contentType, "text/") || contentType == "application/json") {
|
||||
textResult = string(data)
|
||||
}
|
||||
@@ -182,7 +189,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
||||
}
|
||||
|
||||
// 4) 存储 OSS
|
||||
ossURL, err := Storage.UploadByTask(ctx, t, data, ext, contentType)
|
||||
ossURL, err := gateway.UploadByTask(ctx, t, data, ext, contentType)
|
||||
if err != nil {
|
||||
// OSS 阶段失败:保留临时文件,下一轮仅重试 OSS
|
||||
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, t.Id, err.Error())
|
||||
@@ -198,7 +205,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
||||
if fileType == "" {
|
||||
fileType = contentType
|
||||
}
|
||||
if err := dao.Task.UpdateSuccessGlobal(
|
||||
if err = dao.Task.UpdateSuccessGlobal(
|
||||
ctx,
|
||||
t.Id,
|
||||
ossURL,
|
||||
@@ -206,7 +213,7 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
||||
textResult,
|
||||
int64(len(data)),
|
||||
nil,
|
||||
GetExpendTokens(m.TokenMapping, textResult),
|
||||
GetExpendTokens(m.ResponseTokenField, textResult),
|
||||
); err != nil {
|
||||
g.Log().Errorf(ctx, "[worker] update success failed: %v", err)
|
||||
return
|
||||
@@ -221,14 +228,33 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, epicy
|
||||
t.FileType = fileType
|
||||
t.TextResult = textResult
|
||||
g.Log().Infof(ctx, "[CALLBACK][DISPATCH] taskId=%s bizName=%s callbackUrl=%s", t.TaskID, t.BizName, t.CallbackURL)
|
||||
go triggerCallback(context.WithoutCancel(ctx), t)
|
||||
go gateway.TriggerCallback(context.WithoutCancel(ctx), t)
|
||||
// ============ 如果有 epicycleId,也触发业务回调 ============
|
||||
if epicycleId != 0 {
|
||||
go triggerPromptsCallback(context.WithoutCancel(ctx), t, epicycleId)
|
||||
go gateway.TriggerPromptsCallback(context.WithoutCancel(ctx), t, epicycleId)
|
||||
}
|
||||
|
||||
// 成功后清理临时文件
|
||||
deleteTmpResult(t.TmpFile)
|
||||
_ = os.Remove(t.TmpFile)
|
||||
}
|
||||
|
||||
// saveTmpResult 将模型输出写入临时文件,用于 OSS 上传失败后的“仅重试 OSS”。
|
||||
func saveTmpResult(taskID string, data []byte, ext string) (string, error) {
|
||||
dir := filepath.Join(os.TempDir(), "model-asynch")
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if ext == "" {
|
||||
ext = ".bin"
|
||||
}
|
||||
if ext[0] != '.' {
|
||||
ext = "." + ext
|
||||
}
|
||||
path := filepath.Join(dir, fmt.Sprintf("%s%s", taskID, ext))
|
||||
if err := os.WriteFile(path, data, 0o644); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
|
||||
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
|
||||
@@ -240,7 +266,6 @@ func GetExpendTokens(tokenMapping string, textResult string) int {
|
||||
value := gjson.Get(textResult, tokenMapping)
|
||||
if value.Exists() {
|
||||
return int(value.Int())
|
||||
} else {
|
||||
return len(textResult)
|
||||
}
|
||||
return len(textResult)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user