Files
model-gateway/service/model_service.go

274 lines
7.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"context"
"errors"
"sort"
"model-asynch/dao"
"model-asynch/model/dto"
"model-asynch/model/entity"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// Model is the exported, package-level modelService instance.
var Model = &modelService{}
// modelService is a field-less receiver type for the model operations defined in this file.
type modelService struct{}
// Create stores a new model record built from req and returns the ID that the
// DAO reports for it.
//
// Each request field is copied one-to-one onto the entity; no validation or
// defaulting happens at this layer. If the insert fails, the DAO error is
// returned unchanged together with a nil result.
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
	record := entity.AsynchModel{
		ModelName:            req.ModelName,
		ModelsType:           req.ModelsType,
		BaseURL:              req.BaseURL,
		HttpMethod:           req.HttpMethod,
		HeadMsg:              req.HeadMsg,
		IsPrivate:            req.IsPrivate,
		Enabled:              req.Enabled,
		IsChatModel:          req.IsChatModel,
		ApiKey:               req.ApiKey,
		Form:                 req.Form,
		RequestMapping:       req.RequestMapping,
		ResponseMapping:      req.ResponseMapping,
		ResponseBody:         req.ResponseBody,
		TokenMapping:         req.TokenMapping,
		MaxConcurrency:       req.MaxConcurrency,
		QueueLimit:           req.QueueLimit,
		TimeoutSeconds:       req.TimeoutSeconds,
		ExpectedSeconds:      req.ExpectedSeconds,
		RetryTimes:           req.RetryTimes,
		RetryQueueMaxSeconds: req.RetryQueueMaxSeconds,
		AutoCleanSeconds:     req.AutoCleanSeconds,
		Remark:               req.Remark,
	}

	newID, insertErr := dao.Model.Insert(ctx, &record)
	if insertErr != nil {
		return nil, insertErr
	}

	res = &dto.CreateModelRes{ID: newID}
	return res, nil
}
// Update applies req to an existing model record through the DAO.
//
// When the request marks the model as a chat model (IsChatModel == 1), the
// current user's existing chat model is looked up first. The update is
// rejected only if a different model already holds that role; re-saving the
// model that already is the user's chat model is allowed.
//
// Returns any error from the user lookup, the chat-model lookup, or the DAO
// update itself.
func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) error {
	if req.IsChatModel == 1 {
		user, err := utils.GetUserInfo(ctx)
		if err != nil {
			return err
		}
		// Check whether the current user already owns a chat model.
		existing, err := dao.Model.GetByIsChatModel(ctx, user.UserName)
		if err != nil {
			return err
		}
		// Only a different record blocks the update. Without the ID
		// comparison, editing the user's current chat model would always fail.
		if existing != nil && existing.Id != req.ID {
			return errors.New("用户已存在会话模型,不能创建新的会话模型")
		}
	}

	_, err := dao.Model.Update(ctx, req)
	return err
}
// Delete removes the model identified by id via the DAO.
//
// The first value returned by the DAO is discarded; only the error is
// surfaced to the caller.
func (s *modelService) Delete(ctx context.Context, id string) error {
	if _, delErr := dao.Model.DeleteByID(ctx, id); delErr != nil {
		return delErr
	}
	return nil
}
// Get loads the model with the given id and runs its JSON-backed fields
// (Form, RequestMapping, ResponseMapping, ResponseBody) through
// ParseJSONField before returning it.
//
// Returns the DAO error unchanged on failure. Returns (nil, nil) when the DAO
// yields no model without an error; UpdateChatModel in this file treats a nil
// result from dao.Model.Get as "not found", so the same case is guarded here
// instead of dereferencing a nil pointer.
func (s *modelService) Get(ctx context.Context, id int64) (*entity.AsynchModel, error) {
	model, err := dao.Model.Get(ctx, id)
	if err != nil {
		return nil, err
	}
	if model == nil {
		return nil, nil
	}
	model.Form = ParseJSONField(model.Form)
	model.RequestMapping = ParseJSONField(model.RequestMapping)
	model.ResponseMapping = ParseJSONField(model.ResponseMapping)
	model.ResponseBody = ParseJSONField(model.ResponseBody)
	return model, nil
}
// List returns one page of models matching the filters in req, together with
// the total match count.
//
// Super admins are served straight from dao.Model.List. Every other user goes
// through getModelsWithDedup, keyed by the caller's user name. The user info
// is resolved in both cases, so a failure there aborts the call even for
// super admins.
//
// Before returning, the JSON-backed fields of every record are passed through
// ParseJSONField. On any error the result is (nil, 0, err).
func (s *modelService) List(ctx context.Context, pageNum, pageSize int, req *dto.ListModelReq) (list []*entity.AsynchModel, total int64, err error) {
	admin, err := IsSuperAdmin(ctx)
	if err != nil {
		return nil, 0, err
	}
	currentUser, err := utils.GetUserInfo(ctx)
	if err != nil {
		return nil, 0, err
	}

	if admin {
		list, total, err = dao.Model.List(ctx, pageNum, pageSize, req.ModelName, req.ModelType, req.IsPrivate)
	} else {
		list, total, err = s.getModelsWithDedup(ctx, currentUser.UserName, pageNum, pageSize, req.ModelName, req.ModelType, req.IsPrivate)
	}
	if err != nil {
		return nil, 0, err
	}

	// Normalize the JSONB-backed fields of each record in the page.
	for i := range list {
		item := list[i]
		item.Form = ParseJSONField(item.Form)
		item.RequestMapping = ParseJSONField(item.RequestMapping)
		item.ResponseMapping = ParseJSONField(item.ResponseMapping)
		item.ResponseBody = ParseJSONField(item.ResponseBody)
	}
	return list, total, nil
}
// getModelsWithDedup builds the model list for a regular (non-admin) user.
//
// It fetches every candidate record from dao.Model.GetByCreatorAndPlatform
// without paging, de-duplicates them by ModelName, sorts newest first by
// CreatedAt, and then pages in memory.
//
// De-duplication rule: the first record seen for a name is kept, unless a
// later record with the same name was created by creator, in which case that
// record takes the slot.
//
// The de-duplicated slice is built in the order the DAO returned the records
// and sorted with a stable sort, so records sharing the same CreatedAt keep a
// repeatable relative order. An unstable order would let consecutive pages
// repeat or drop records.
//
// total is the number of records after de-duplication. If pageNum or pageSize
// is not positive, the whole de-duplicated list is returned. A page beyond
// the end yields an empty, non-nil slice.
func (s *modelService) getModelsWithDedup(ctx context.Context, creator string, pageNum, pageSize int, modelNameLike string, modelType int, isPrivate int) (list []*entity.AsynchModel, total int64, err error) {
	// 1. Load the full result set (no paging) so that de-duplication sees everything.
	allModels, err := dao.Model.GetByCreatorAndPlatform(ctx, creator, modelNameLike, modelType, isPrivate)
	if err != nil {
		return nil, 0, err
	}

	// 2. De-duplicate by model name, preferring the caller's own record.
	//    slotByName maps a name to its position in deduped, so a replacement
	//    keeps the slot of the record it displaces.
	slotByName := make(map[string]int, len(allModels))
	deduped := make([]*entity.AsynchModel, 0, len(allModels))
	for _, m := range allModels {
		if m == nil {
			continue
		}
		slot, seen := slotByName[m.ModelName]
		switch {
		case !seen:
			// No conflict: take the record as-is.
			slotByName[m.ModelName] = len(deduped)
			deduped = append(deduped, m)
		case m.Creator == creator:
			// Conflict: the caller's own record wins.
			deduped[slot] = m
		}
	}

	// 3. Newest first; the stable sort preserves input order for ties.
	sort.SliceStable(deduped, func(i, j int) bool {
		return deduped[i].CreatedAt.After(deduped[j].CreatedAt)
	})

	// 4. Page in memory.
	total = int64(len(deduped))
	if pageNum > 0 && pageSize > 0 {
		start := (pageNum - 1) * pageSize
		if start >= len(deduped) {
			return []*entity.AsynchModel{}, total, nil
		}
		end := start + pageSize
		if end > len(deduped) {
			end = len(deduped)
		}
		deduped = deduped[start:end]
	}
	return deduped, total, nil
}
// GetModelTypesFromConfig returns the model-type ID to display-name mapping
// read from the "modelType.types" configuration key.
//
// Entries whose key does not convert to a positive int, or whose value
// converts to an empty string, are skipped. If the configuration cannot be
// read, or yields no usable entry, the built-in default mapping is returned.
//
// A configuration read error is logged as a warning and does not panic, so a
// broken or missing config file degrades to the defaults.
func GetModelTypesFromConfig(ctx context.Context) map[int]string {
	typeMap := make(map[int]string)

	// Read the configuration. Get is used instead of MustGet so that a read
	// failure falls through to the defaults below rather than panicking.
	cfgVal, err := g.Cfg().Get(ctx, "modelType.types")
	if err != nil {
		g.Log().Warningf(ctx, "read config modelType.types failed, using default model types: %v", err)
	} else if cfgVal != nil {
		for k, v := range cfgVal.Map() {
			typeID := gconv.Int(k)
			typeName := gconv.String(v)
			if typeID > 0 && typeName != "" {
				typeMap[typeID] = typeName
			}
		}
	}

	// Fall back to the defaults when nothing usable was configured.
	if len(typeMap) == 0 {
		typeMap = map[int]string{
			1: "推理模型",
			2: "图片模型",
			3: "音频模型",
			4: "向量化模型",
			5: "全模态模型",
		}
	}
	return typeMap
}
// UpdateChatModel toggles or switches the current user's chat model to the
// model identified by req.Id.
//
// Behavior:
//   - The target model must exist and have ModelsType == 1; otherwise an
//     error is returned.
//   - If the target already is the user's chat model, its flag is cleared
//     (toggle off) and nothing else changes.
//   - Otherwise any existing chat model of the user is cleared first, and the
//     target is then flagged as the chat model.
//
// The type check is performed on the target model. The user's current chat
// model may be nil (no chat model yet) and is only dereferenced after a nil
// check.
//
// NOTE(review): the clear and set updates are issued as separate DAO calls
// with no transaction visible in this function. A failure of the second call
// leaves the user without a chat model.
func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error {
	user, err := utils.GetUserInfo(ctx)
	if err != nil {
		return err
	}

	// Verify that the model to be promoted exists.
	target, err := dao.Model.Get(ctx, req.Id)
	if err != nil {
		return err
	}
	if target == nil {
		return errors.New("新会话模型不存在")
	}

	// Look up the user's current chat model; nil means none is set.
	current, err := dao.Model.GetByIsChatModel(ctx, user.UserName)
	if err != nil {
		return err
	}

	// Only models of type 1 may serve as the chat model.
	if target.ModelsType != 1 {
		return errors.New("当前模型为非推理模型,不能设置为会话模型")
	}

	if current != nil {
		// Clear the flag on the existing chat model.
		_, err = dao.Model.UpdateByID(ctx, &dto.UpdateModelReq{
			ID:          current.Id,
			IsChatModel: 0,
		})
		if err != nil {
			return err
		}
		// The caller selected the model that was already active: this is a
		// toggle-off, so stop after clearing it.
		if current.Id == req.Id {
			return nil
		}
	}

	// Flag the target as the chat model.
	_, err = dao.Model.UpdateByID(ctx, &dto.UpdateModelReq{
		ID:          req.Id,
		IsChatModel: 1,
	})
	return err
}
// GetIsChatModel fetches the chat model registered for the current user.
//
// The user is resolved from ctx and the lookup is keyed by user name. When the
// DAO yields no model, (nil, nil) is returned. Otherwise the JSON-backed fields
// are passed through ParseJSONField before the record is handed back. Errors
// from the user lookup or the DAO are returned unchanged.
func (s *modelService) GetIsChatModel(ctx context.Context) (*entity.AsynchModel, error) {
	currentUser, err := utils.GetUserInfo(ctx)
	if err != nil {
		return nil, err
	}

	chatModel, err := dao.Model.GetByIsChatModel(ctx, currentUser.UserName)
	switch {
	case err != nil:
		return nil, err
	case chatModel == nil:
		// No chat model configured for this user.
		return nil, nil
	}

	chatModel.Form = ParseJSONField(chatModel.Form)
	chatModel.RequestMapping = ParseJSONField(chatModel.RequestMapping)
	chatModel.ResponseMapping = ParseJSONField(chatModel.ResponseMapping)
	chatModel.ResponseBody = ParseJSONField(chatModel.ResponseBody)
	return chatModel, nil
}