diff --git a/service/model_service.go b/service/model_service.go index 2fba6af..23bc5c5 100644 --- a/service/model_service.go +++ b/service/model_service.go @@ -25,8 +25,17 @@ type modelService struct{} func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) { // 获取当前会话模型 if !g.IsEmpty(req.IsChatModel) && *req.IsChatModel == 1 { + var user *beans.User + user, err = utils.GetUserInfo(ctx) + if err != nil { + return nil, err + } + // 获取当前用户会话模型 var model *entity.AsynchModel model, err = dao.Model.Get(ctx, &entity.AsynchModel{ + SQLBaseDO: beans.SQLBaseDO{ + Creator: user.UserName, + }, IsChatModel: new(1), }) if err != nil { @@ -88,8 +97,15 @@ func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) error { //根据当前 isChatModel 来判断是否更新模型 if req.IsChatModel == gconv.PtrInt(1) { - //判断当前用户是否有会话模型 + user, err := utils.GetUserInfo(ctx) + if err != nil { + return err + } + // 获取当前用户会话模型 model, err := dao.Model.Get(ctx, &entity.AsynchModel{ + SQLBaseDO: beans.SQLBaseDO{ + Creator: user.UserName, + }, IsChatModel: new(1), }) if err != nil { @@ -298,8 +314,16 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM if newModel == nil { return errors.New("新会话模型不存在") } + var user *beans.User + user, err = utils.GetUserInfo(ctx) + if err != nil { + return err + } // 获取当前用户会话模型 currentModel, err := dao.Model.Get(ctx, &entity.AsynchModel{ + SQLBaseDO: beans.SQLBaseDO{ + Creator: user.UserName, + }, IsChatModel: new(1), }) if err != nil { @@ -307,7 +331,7 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM } err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error { if !g.IsEmpty(currentModel) { - if currentModel.ModelType != 1 { + if currentModel.ModelType != public.ModelTypeInference { return errors.New("当前模型为非推理模型,不能设置为会话模型") } @@ -325,7 +349,7 @@ func (s *modelService) UpdateChatModel(ctx context.Context, req *dto.UpdateChatM // 设置当前为会话模型(设为1) _, err = dao.Model.Update(ctx, &entity.AsynchModel{ - SQLBaseDO: beans.SQLBaseDO{Id: currentModel.Id}, + SQLBaseDO: beans.SQLBaseDO{Id: req.Id}, IsChatModel: gconv.PtrInt(1), }) return err