refactor(model): 重构模型实体和数据访问层

This commit is contained in:
2026-05-21 10:41:37 +08:00
parent a080a5536d
commit 170568e03e
35 changed files with 903 additions and 1072 deletions

View File

@@ -3,13 +3,15 @@ package service
import (
"context"
"errors"
"model-gateway/common/util"
"model-gateway/consts/public"
"model-gateway/dao"
"model-gateway/model/dto"
"model-gateway/model/entity"
"model-gateway/service/gateway"
"gitea.com/red-future/common/beans"
"gitea.com/red-future/common/db/gfdb"
"gitea.com/red-future/common/http"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
@@ -20,28 +22,20 @@ var Model = &modelService{}
type modelService struct{}
// IsSuperAdmin 调用admin-go服务检查是否是超级管理员
func (s *modelService) IsSuperAdmin(ctx context.Context) (res bool, err error) {
headers := forwardHeaders(ctx)
var r = make(map[string]bool)
if err = http.Get(ctx, "admin-go/api/v1/system/user/checkIsSuperAdmin", headers, &r); err != nil {
return false, err
}
return r["isSuperAdmin"], err
}
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
// 获取当前会话模型
if !g.IsEmpty(req.IsChatModel) && *req.IsChatModel == 1 {
var model *entity.AsynchModel
model, err = dao.Model.GetByIsChatModel(ctx)
model, err = dao.Model.Get(ctx, &entity.AsynchModel{
IsChatModel: new(1),
})
if err != nil {
return nil, err
}
// 如果有会话模型,那就改变为 0
if model != nil {
_, err = dao.Model.Update(ctx, &dto.UpdateModelReq{
ID: model.Id,
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: model.Id},
IsChatModel: gconv.PtrInt(0),
})
if err != nil {
@@ -51,14 +45,40 @@ func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res
}
req.IsOwner = gconv.PtrInt(1)
admin, err := s.IsSuperAdmin(ctx)
admin, err := gateway.IsSuperAdmin(ctx)
if err != nil {
return
}
if admin {
req.IsOwner = gconv.PtrInt(0)
}
id, err := dao.Model.Insert(ctx, req)
id, err := dao.Model.Insert(ctx, &entity.AsynchModel{
ModelName: req.ModelName,
ModelType: req.ModelType,
BaseURL: req.BaseURL,
HttpMethod: req.HttpMethod,
HeadMsg: req.HeadMsg,
Form: req.Form,
RequestMapping: req.RequestMapping,
ResponseMapping: req.ResponseMapping,
ResponseBody: req.ResponseBody,
ResponseTokenField: req.ResponseTokenField,
IsPrivate: req.IsPrivate,
IsChatModel: req.IsChatModel,
ApiKey: req.ApiKey,
Enabled: req.Enabled,
MaxConcurrency: req.MaxConcurrency,
QueueLimit: req.QueueLimit,
TimeoutSeconds: req.TimeoutSeconds,
ExpectedSeconds: req.ExpectedSeconds,
RetryTimes: req.RetryTimes,
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
AutoCleanSeconds: req.AutoCleanSeconds,
Remark: req.Remark,
IsOwner: req.IsOwner,
OperatorName: req.OperatorName,
TokenConfig: req.TokenConfig,
})
if err != nil {
return nil, err
}
@@ -69,7 +89,9 @@ func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) erro
//根据当前 isChatModel 来判断是否更新模型
if req.IsChatModel == gconv.PtrInt(1) {
//判断当前用户是否有会话模型
model, err := dao.Model.GetByIsChatModel(ctx)
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
IsChatModel: new(1),
})
if err != nil {
return err
}
@@ -79,68 +101,146 @@ func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) erro
}
req.IsOwner = gconv.PtrInt(1)
admin, err := s.IsSuperAdmin(ctx)
admin, err := gateway.IsSuperAdmin(ctx)
if err != nil {
return err
}
if admin {
req.IsOwner = gconv.PtrInt(0)
_, err = dao.Model.Update(ctx, req)
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
ModelName: req.ModelName,
ModelType: req.ModelType,
BaseURL: req.BaseURL,
HttpMethod: req.HttpMethod,
HeadMsg: req.HeadMsg,
Form: req.Form,
RequestMapping: req.RequestMapping,
ResponseMapping: req.ResponseMapping,
ResponseBody: req.ResponseBody,
ResponseTokenField: req.ResponseTokenField,
IsPrivate: req.IsPrivate,
IsChatModel: req.IsChatModel,
ApiKey: req.ApiKey,
Enabled: req.Enabled,
MaxConcurrency: req.MaxConcurrency,
QueueLimit: req.QueueLimit,
TimeoutSeconds: req.TimeoutSeconds,
ExpectedSeconds: req.ExpectedSeconds,
RetryTimes: req.RetryTimes,
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
AutoCleanSeconds: req.AutoCleanSeconds,
Remark: req.Remark,
IsOwner: req.IsOwner,
OperatorName: req.OperatorName,
TokenConfig: req.TokenConfig,
})
if err != nil {
return err
}
return nil
}
var user *beans.User
user, err = utils.GetUserInfo(ctx)
if err != nil {
return err
}
// 判断当前传过来的模型id的模型是否是超级管理员的。如果是超管的进行创建否则更新
var count int
count, err = dao.Model.Count(ctx, &dto.GetModelReq{
ID: req.ID,
Creator: user.UserName,
model, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
})
if err != nil {
return err
}
if count == 0 {
if model.TenantId == 1 {
insertDto := new(dto.CreateModelReq)
err = gconv.Struct(req, insertDto)
if err != nil {
return err
}
_, err = dao.Model.Insert(ctx, insertDto)
_, err = dao.Model.Insert(ctx, &entity.AsynchModel{
ModelName: req.ModelName,
ModelType: req.ModelType,
BaseURL: req.BaseURL,
HttpMethod: req.HttpMethod,
HeadMsg: req.HeadMsg,
Form: req.Form,
RequestMapping: req.RequestMapping,
ResponseMapping: req.ResponseMapping,
ResponseBody: req.ResponseBody,
ResponseTokenField: req.ResponseTokenField,
IsPrivate: req.IsPrivate,
IsChatModel: req.IsChatModel,
ApiKey: req.ApiKey,
Enabled: req.Enabled,
MaxConcurrency: req.MaxConcurrency,
QueueLimit: req.QueueLimit,
TimeoutSeconds: req.TimeoutSeconds,
ExpectedSeconds: req.ExpectedSeconds,
RetryTimes: req.RetryTimes,
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
AutoCleanSeconds: req.AutoCleanSeconds,
Remark: req.Remark,
IsOwner: req.IsOwner,
OperatorName: req.OperatorName,
TokenConfig: req.TokenConfig,
})
return err
}
_, err = dao.Model.Update(ctx, req)
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
ModelName: req.ModelName,
ModelType: req.ModelType,
BaseURL: req.BaseURL,
HttpMethod: req.HttpMethod,
HeadMsg: req.HeadMsg,
Form: req.Form,
RequestMapping: req.RequestMapping,
ResponseMapping: req.ResponseMapping,
ResponseBody: req.ResponseBody,
ResponseTokenField: req.ResponseTokenField,
IsPrivate: req.IsPrivate,
IsChatModel: req.IsChatModel,
ApiKey: req.ApiKey,
Enabled: req.Enabled,
MaxConcurrency: req.MaxConcurrency,
QueueLimit: req.QueueLimit,
TimeoutSeconds: req.TimeoutSeconds,
ExpectedSeconds: req.ExpectedSeconds,
RetryTimes: req.RetryTimes,
RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
AutoCleanSeconds: req.AutoCleanSeconds,
Remark: req.Remark,
IsOwner: req.IsOwner,
OperatorName: req.OperatorName,
TokenConfig: req.TokenConfig,
})
return err
}
func (s *modelService) Delete(ctx context.Context, id string) error {
_, err := dao.Model.DeleteByID(ctx, id)
func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) error {
_, err := dao.Model.Delete(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
})
return err
}
func (s *modelService) Get(ctx context.Context, id int64) (*entity.AsynchModel, error) {
model, err := dao.Model.Get(ctx, id)
func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetModelRes, error) {
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
})
if err != nil {
return nil, err
}
model.Form = ParseJSONField(model.Form)
model.RequestMapping = ParseJSONField(model.RequestMapping)
model.ResponseMapping = ParseJSONField(model.ResponseMapping)
model.ResponseBody = ParseJSONField(model.ResponseBody)
return model, nil
model.Form = util.ParseJSONField(model.Form)
model.RequestMapping = util.ParseJSONField(model.RequestMapping)
model.ResponseMapping = util.ParseJSONField(model.ResponseMapping)
model.ResponseBody = util.ParseJSONField(model.ResponseBody)
return &dto.GetModelRes{
Model: model,
}, nil
}
func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (list []*entity.AsynchModel, total int, err error) {
func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) {
var models []*entity.AsynchModel
req.IsOwner = gconv.PtrInt(1)
admin, err := s.IsSuperAdmin(ctx)
admin, err := gateway.IsSuperAdmin(ctx)
if err != nil {
return
}
@@ -151,63 +251,55 @@ func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (list []
var user *beans.User
user, err = utils.GetUserInfo(ctx)
if err != nil {
return nil, 0, err
return nil, err
}
req.Creator = user.UserName
models, total, err = dao.Model.GetByCreatorAndPlatform(ctx, req)
models, total, err := dao.Model.GetByCreatorAndPlatform(ctx, req)
if err != nil {
return
}
// 处理列表中每条记录的 JSONB 字段
for _, m := range models {
m.Form = ParseJSONField(m.Form)
m.RequestMapping = ParseJSONField(m.RequestMapping)
m.ResponseMapping = ParseJSONField(m.ResponseMapping)
m.ResponseBody = ParseJSONField(m.ResponseBody)
m.Form = util.ParseJSONField(m.Form)
m.RequestMapping = util.ParseJSONField(m.RequestMapping)
m.ResponseMapping = util.ParseJSONField(m.ResponseMapping)
m.ResponseBody = util.ParseJSONField(m.ResponseBody)
}
return models, total, nil
return &dto.ListModelRes{
List: models,
Total: total,
}, nil
}
// GetModelTypesFromConfig 从配置文件读取模型类型
func GetModelTypesFromConfig(ctx context.Context) map[int]string {
typeMap := make(map[int]string)
// 读取配置
configMap := g.Cfg().MustGet(ctx, "modelType.types").Map()
for k, v := range configMap {
typeID := gconv.Int(k)
typeName := gconv.String(v)
if typeID > 0 && typeName != "" {
typeMap[typeID] = typeName
}
func GetModelTypesFromConfig() (res *dto.TypeItem, err error) {
// 返回副本,避免外部修改
types := make(map[int]string, len(public.ModelTypeName))
for k, v := range public.ModelTypeName {
types[k] = v
}
// 如果配置为空,使用默认值
if len(typeMap) == 0 {
typeMap = map[int]string{
1: "推理模型",
2: "图片模型",
3: "音频模型",
4: "向量化模型",
5: "全模态模型",
}
}
return typeMap
return &dto.TypeItem{
Type: types,
}, nil
}
func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error {
// 校验新会话模型是否存在
newModel, err := dao.Model.Get(ctx, req.Id)
newModel, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: req.Id},
})
if err != nil {
return err
}
if newModel == nil {
return errors.New("新会话模型不存在")
}
// 获取当前用户会话模型
currentModel, err := dao.Model.GetByIsChatModel(ctx)
currentModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
IsChatModel: new(1),
})
if err != nil {
return err
}
@@ -219,8 +311,8 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM
// 如果点击的就是当前会话模型已经是1取消它设为0
if currentModel.Id != req.Id {
_, err = dao.Model.Update(ctx, &dto.UpdateModelReq{
ID: currentModel.Id,
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id},
IsChatModel: gconv.PtrInt(0),
})
if err != nil {
@@ -230,8 +322,8 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM
}
// 设置当前为会话模型设为1
_, err = dao.Model.Update(ctx, &dto.UpdateModelReq{
ID: req.Id,
_, err = dao.Model.Update(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id},
IsChatModel: gconv.PtrInt(1),
})
return err
@@ -239,17 +331,21 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM
return err
}
func (s *modelService) GetIsChatModel(ctx context.Context) (*entity.AsynchModel, error) {
model, err := dao.Model.GetByIsChatModel(ctx)
func (s *modelService) GetIsChatModel(ctx context.Context) (*dto.GetIsChatModelRes, error) {
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
IsChatModel: new(1),
})
if err != nil {
return nil, err
}
if model == nil {
return nil, nil
}
model.Form = ParseJSONField(model.Form)
model.RequestMapping = ParseJSONField(model.RequestMapping)
model.ResponseMapping = ParseJSONField(model.ResponseMapping)
model.ResponseBody = ParseJSONField(model.ResponseBody)
return model, nil
model.Form = util.ParseJSONField(model.Form)
model.RequestMapping = util.ParseJSONField(model.RequestMapping)
model.ResponseMapping = util.ParseJSONField(model.ResponseMapping)
model.ResponseBody = util.ParseJSONField(model.ResponseBody)
return &dto.GetIsChatModelRes{
Model: model,
}, nil
}