package service import ( "context" "errors" "model-gateway/common/util" "model-gateway/consts/public" "model-gateway/dao" "model-gateway/model/dto" "model-gateway/model/entity" "model-gateway/service/gateway" "gitea.com/red-future/common/beans" "gitea.com/red-future/common/db/gfdb" "gitea.com/red-future/common/utils" "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/util/gconv" ) var Model = &modelService{} type modelService struct{} func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) { // 获取当前会话模型 if !g.IsEmpty(req.IsChatModel) && *req.IsChatModel == 1 { var model *entity.AsynchModel model, err = dao.Model.Get(ctx, &entity.AsynchModel{ IsChatModel: new(1), }) if err != nil { return nil, err } // 如果有会话模型,那就改变为 0 if model != nil { _, err = dao.Model.Update(ctx, &entity.AsynchModel{ SQLBaseDO: beans.SQLBaseDO{Id: model.Id}, IsChatModel: gconv.PtrInt(0), }) if err != nil { return nil, err } } } req.IsOwner = gconv.PtrInt(1) admin, err := gateway.IsSuperAdmin(ctx) if err != nil { return } if admin { req.IsOwner = gconv.PtrInt(0) } id, err := dao.Model.Insert(ctx, &entity.AsynchModel{ ModelName: req.ModelName, ModelType: req.ModelType, BaseURL: req.BaseURL, HttpMethod: req.HttpMethod, HeadMsg: req.HeadMsg, Form: req.Form, RequestMapping: req.RequestMapping, ResponseMapping: req.ResponseMapping, ResponseBody: req.ResponseBody, ResponseTokenField: req.ResponseTokenField, IsPrivate: req.IsPrivate, IsChatModel: req.IsChatModel, ApiKey: req.ApiKey, Enabled: req.Enabled, MaxConcurrency: req.MaxConcurrency, QueueLimit: req.QueueLimit, TimeoutSeconds: req.TimeoutSeconds, ExpectedSeconds: req.ExpectedSeconds, RetryTimes: req.RetryTimes, RetryQueueMaxSeconds: req.RetryQueueMaxSeconds, AutoCleanSeconds: req.AutoCleanSeconds, Remark: req.Remark, IsOwner: req.IsOwner, OperatorName: req.OperatorName, TokenConfig: req.TokenConfig, }) if err != nil { return nil, err } return &dto.CreateModelRes{ID: id}, nil } func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) error { //根据当前 isChatModel 来判断是否更新模型 if req.IsChatModel == gconv.PtrInt(1) { //判断当前用户是否有会话模型 model, err := dao.Model.Get(ctx, &entity.AsynchModel{ IsChatModel: new(1), }) if err != nil { return err } if model != nil { return errors.New("用户已存在会话模型,不能创建") } } req.IsOwner = gconv.PtrInt(1) admin, err := gateway.IsSuperAdmin(ctx) if err != nil { return err } if admin { req.IsOwner = gconv.PtrInt(0) _, err = dao.Model.Update(ctx, &entity.AsynchModel{ SQLBaseDO: beans.SQLBaseDO{Id: req.ID}, ModelName: req.ModelName, ModelType: req.ModelType, BaseURL: req.BaseURL, HttpMethod: req.HttpMethod, HeadMsg: req.HeadMsg, Form: req.Form, RequestMapping: req.RequestMapping, ResponseMapping: req.ResponseMapping, ResponseBody: req.ResponseBody, ResponseTokenField: req.ResponseTokenField, IsPrivate: req.IsPrivate, IsChatModel: req.IsChatModel, ApiKey: req.ApiKey, Enabled: req.Enabled, MaxConcurrency: req.MaxConcurrency, QueueLimit: req.QueueLimit, TimeoutSeconds: req.TimeoutSeconds, ExpectedSeconds: req.ExpectedSeconds, RetryTimes: req.RetryTimes, RetryQueueMaxSeconds: req.RetryQueueMaxSeconds, AutoCleanSeconds: req.AutoCleanSeconds, Remark: req.Remark, IsOwner: req.IsOwner, OperatorName: req.OperatorName, TokenConfig: req.TokenConfig, }) if err != nil { return err } return nil } // 判断当前传过来的模型id的模型是否是超级管理员的。如果是超管的进行创建,否则更新 model, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{ SQLBaseDO: beans.SQLBaseDO{Id: req.ID}, }) if err != nil { return err } if model.TenantId == 1 { insertDto := new(dto.CreateModelReq) err = gconv.Struct(req, insertDto) if err != nil { return err } _, err = dao.Model.Insert(ctx, &entity.AsynchModel{ ModelName: req.ModelName, ModelType: req.ModelType, BaseURL: req.BaseURL, HttpMethod: req.HttpMethod, HeadMsg: req.HeadMsg, Form: req.Form, RequestMapping: req.RequestMapping, ResponseMapping: req.ResponseMapping, ResponseBody: req.ResponseBody, ResponseTokenField: req.ResponseTokenField, IsPrivate: req.IsPrivate, IsChatModel: req.IsChatModel, ApiKey: req.ApiKey, Enabled: req.Enabled, MaxConcurrency: req.MaxConcurrency, QueueLimit: req.QueueLimit, TimeoutSeconds: req.TimeoutSeconds, ExpectedSeconds: req.ExpectedSeconds, RetryTimes: req.RetryTimes, RetryQueueMaxSeconds: req.RetryQueueMaxSeconds, AutoCleanSeconds: req.AutoCleanSeconds, Remark: req.Remark, IsOwner: req.IsOwner, OperatorName: req.OperatorName, TokenConfig: req.TokenConfig, }) return err } _, err = dao.Model.Update(ctx, &entity.AsynchModel{ SQLBaseDO: beans.SQLBaseDO{Id: req.ID}, ModelName: req.ModelName, ModelType: req.ModelType, BaseURL: req.BaseURL, HttpMethod: req.HttpMethod, HeadMsg: req.HeadMsg, Form: req.Form, RequestMapping: req.RequestMapping, ResponseMapping: req.ResponseMapping, ResponseBody: req.ResponseBody, ResponseTokenField: req.ResponseTokenField, IsPrivate: req.IsPrivate, IsChatModel: req.IsChatModel, ApiKey: req.ApiKey, Enabled: req.Enabled, MaxConcurrency: req.MaxConcurrency, QueueLimit: req.QueueLimit, TimeoutSeconds: req.TimeoutSeconds, ExpectedSeconds: req.ExpectedSeconds, RetryTimes: req.RetryTimes, RetryQueueMaxSeconds: req.RetryQueueMaxSeconds, AutoCleanSeconds: req.AutoCleanSeconds, Remark: req.Remark, IsOwner: req.IsOwner, OperatorName: req.OperatorName, TokenConfig: req.TokenConfig, }) return err } func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) error { _, err := dao.Model.Delete(ctx, &entity.AsynchModel{ SQLBaseDO: beans.SQLBaseDO{Id: req.ID}, }) return err } func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetModelRes, error) { model, err := dao.Model.Get(ctx, &entity.AsynchModel{ SQLBaseDO: beans.SQLBaseDO{Id: req.ID}, }) if err != nil { return nil, err } model.Form = util.ParseJSONField(model.Form) model.RequestMapping = util.ParseJSONField(model.RequestMapping) model.ResponseMapping = util.ParseJSONField(model.ResponseMapping) model.ResponseBody = util.ParseJSONField(model.ResponseBody) model.TokenConfig = util.ParseJSONField(model.TokenConfig) return &dto.GetModelRes{ Model: model, }, nil } func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) { var models []*entity.AsynchModel req.IsOwner = gconv.PtrInt(1) admin, err := gateway.IsSuperAdmin(ctx) if err != nil { return } if admin { req.IsOwner = gconv.PtrInt(0) } var user *beans.User user, err = utils.GetUserInfo(ctx) if err != nil { return nil, err } req.Creator = user.UserName models, total, err := dao.Model.GetByCreatorAndPlatform(ctx, req) if err != nil { return } // 处理列表中每条记录的 JSONB 字段 for _, m := range models { m.Form = util.ParseJSONField(m.Form) m.RequestMapping = util.ParseJSONField(m.RequestMapping) m.ResponseMapping = util.ParseJSONField(m.ResponseMapping) m.ResponseBody = util.ParseJSONField(m.ResponseBody) m.TokenConfig = util.ParseJSONField(m.TokenConfig) } return &dto.ListModelRes{ List: models, Total: total, }, nil } // GetModelTypesFromConfig 从配置文件读取模型类型 func GetModelTypesFromConfig() (res *dto.TypeItem, err error) { // 返回副本,避免外部修改 types := make(map[int]string, len(public.ModelTypeName)) for k, v := range public.ModelTypeName { types[k] = v } return &dto.TypeItem{ Type: types, }, nil } func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatModelReq) error { // 校验新会话模型是否存在 newModel, err := dao.Model.GetByAcrossTenant(ctx, &entity.AsynchModel{ SQLBaseDO: beans.SQLBaseDO{Id: req.Id}, }) if err != nil { return err } if newModel == nil { return errors.New("新会话模型不存在") } // 获取当前用户会话模型 currentModel, err := dao.Model.Get(ctx, &entity.AsynchModel{ IsChatModel: new(1), }) if err != nil { return err } err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error { if !g.IsEmpty(currentModel) { if currentModel.ModelType != 1 { return errors.New("当前模型为非推理模型,不能设置为会话模型") } // 如果点击的就是当前会话模型(已经是1),取消它(设为0) if currentModel.Id != req.Id { _, err = dao.Model.Update(ctx, &entity.AsynchModel{ SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id}, IsChatModel: gconv.PtrInt(0), }) if err != nil { return err } } } // 设置当前为会话模型(设为1) _, err = dao.Model.Update(ctx, &entity.AsynchModel{ SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id}, IsChatModel: gconv.PtrInt(1), }) return err }) return err } func (s *modelService) GetIsChatModel(ctx context.Context) (*dto.GetIsChatModelRes, error) { model, err := dao.Model.Get(ctx, &entity.AsynchModel{ IsChatModel: new(1), }) if err != nil { return nil, err } if model == nil { return nil, nil } model.Form = util.ParseJSONField(model.Form) model.RequestMapping = util.ParseJSONField(model.RequestMapping) model.ResponseMapping = util.ParseJSONField(model.ResponseMapping) model.ResponseBody = util.ParseJSONField(model.ResponseBody) model.TokenConfig = util.ParseJSONField(model.TokenConfig) return &dto.GetIsChatModelRes{ Model: model, }, nil }