Files

300 lines
9.4 KiB
Go
Raw Permalink Normal View History

package service
import (
"context"
"rag/common/eino"
"rag/consts/model"
"rag/consts/task"
"rag/dao"
"rag/model/dto"
"rag/model/entity"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// ModelService is the package-level instance through which the model
// service methods are called.
var ModelService = &modelService{}

// modelService holds no state; it only groups the model-related service methods.
type modelService struct{}
// GetModelAllEnums returns every model type enum together with the
// configuration type enums that belong to it, merged into one option list.
//
// ctx and req are accepted but not used. If converting the configuration
// options of a model type fails, the conversion error is returned together
// with the (still empty) response object.
func (s *modelService) GetModelAllEnums(ctx context.Context, req *dto.GetModelAllEnumsReq) (res *dto.GetModelEnumRes, err error) {
	_, _ = ctx, req

	res = &dto.GetModelEnumRes{}

	// Collect one merged option per model type.
	var merged []dto.ModelEnumOption
	typeEnums := model.GetAllModelTypeEnums()
	for _, typeOpt := range typeEnums.Options {
		// Rebuild the model type value from the option key.
		typeCode := model.ModelType(gconv.PtrString(gconv.String(typeOpt.Key)))

		// Look up the configuration types for this model type and convert
		// them into the DTO key/value shape.
		var configTypes []dto.ModelKeyValue
		configEnums := model.GetAllModelConfigTypeEnums(typeCode)
		if err = gconv.Structs(configEnums.Options, &configTypes); err != nil {
			return res, err
		}

		merged = append(merged, dto.ModelEnumOption{
			Key:         typeOpt.Key,
			Value:       typeOpt.Value,
			ConfigTypes: configTypes,
		})
	}

	res.Options = merged
	return res, nil
}
// GetModelConfigFormFields builds the list of form-field descriptors used to
// create or edit a model configuration for the requested model type and
// configuration type.
//
// The result always starts with the fixed base fields (read-only model type,
// read-only config type, model name, model description, model identifier) and
// is followed by the provider-specific fields selected by the
// (ModelType, ConfigType) pair. Unknown combinations yield only the base
// fields.
//
// An error is returned when req, req.ModelType or req.ConfigType is nil; both
// codes are dereferenced to select the dynamic fields, so a missing value
// would otherwise cause a nil-pointer panic. ctx is not used.
func (s *modelService) GetModelConfigFormFields(ctx context.Context, req *dto.GetModelConfigFormFieldsReq) (*dto.GetModelConfigFormFieldsRes, error) {
	_ = ctx

	// Validate up front instead of panicking on a nil dereference below.
	if req == nil || req.ModelType == nil {
		return nil, gerror.New("模型类型不能为空")
	}
	if req.ConfigType == nil {
		return nil, gerror.New("配置类型不能为空")
	}

	// Local alias only to keep the literals below readable; the underlying
	// type is unchanged.
	type field = map[string]interface{}

	modelType, configType := *req.ModelType, *req.ConfigType
	fields := make([]field, 0)

	// ===================== Fixed base fields =====================
	// (the original note says these correspond to the first four fields of
	// CreateModelReq)

	// 1. Model type: fixed, read-only.
	fields = append(fields, field{
		"name":     "modelType",
		"label":    "模型类型",
		"type":     "text",
		"disabled": true,
		"required": true,
		"value":    model.GetModelTypeDescByCode(req.ModelType),
	})

	// Resolve the human-readable config type; falls back to "unknown type".
	configTypeValue := "未知类型"
	switch modelType {
	case *model.ModelTypeVector.Code():
		configTypeValue = model.GetVectorDescByCode(req.ConfigType)
	case *model.ModelTypeChat.Code():
		configTypeValue = model.GetChatDescByCode(req.ConfigType)
	}

	// 2. Config type: fixed, read-only.
	fields = append(fields, field{
		"name":     "configType",
		"label":    "配置类型",
		"type":     "text",
		"disabled": true,
		"required": true,
		"value":    configTypeValue,
	})

	// 3. Basic information.
	fields = append(fields,
		field{
			"name":        "modelName",
			"label":       "模型名称",
			"type":        "input",
			"required":    true,
			"placeholder": "例如DeepSeek 对话模型",
		},
		field{
			"name":     "modelDesc",
			"label":    "模型描述",
			"type":     "textarea",
			"required": false,
		},
	)

	// 4. Common model identifier field.
	fields = append(fields, field{
		"name":        "model",
		"label":       "模型类型",
		"type":        "input",
		"required":    true,
		"placeholder": "例如deepseek-chat / text-embedding-3-small",
	})

	// ===================== Dynamic ConfigContent fields =====================
	// Selected by model type first, then by configuration type.
	switch modelType {
	case *model.ModelTypeChat.Code():
		switch configType {
		case *model.ModelConfigTypeChatArk.Code():
			fields = append(fields, field{"name": "api_key", "label": "API Key", "type": "input", "required": true})
		case *model.ModelConfigTypeChatArkBot.Code():
			fields = append(fields, field{"name": "api_key", "label": "API Key", "type": "input", "required": true})
		case *model.ModelConfigTypeChatClaude.Code():
			fields = append(fields,
				field{"name": "by_bedrock", "label": "使用 AWS Bedrock", "type": "switch", "default": true},
				field{"name": "access_key", "label": "Access Key", "type": "input"},
				field{"name": "secret_access_key", "label": "Secret Access Key", "type": "input"},
				field{"name": "region", "label": "Region", "type": "input"},
				field{"name": "api_key", "label": "API Key", "type": "input"},
				field{"name": "base_url", "label": "Base URL", "type": "input"},
			)
		case *model.ModelConfigTypeChatDeepSeek.Code():
			fields = append(fields,
				field{"name": "api_key", "label": "API Key", "type": "input", "required": true},
				field{"name": "base_url", "label": "Base URL", "type": "input", "default": "https://api.deepseek.com"},
			)
		case *model.ModelConfigTypeChatOllama.Code():
			fields = append(fields, field{"name": "base_url", "label": "Base URL", "type": "input", "required": true, "default": "http://127.0.0.1:11434"})
		case *model.ModelConfigTypeChatOpenAI.Code():
			fields = append(fields,
				field{"name": "api_key", "label": "API Key", "type": "input", "required": true},
				field{"name": "by_azure", "label": "使用 Azure", "type": "switch", "default": true},
				field{"name": "base_url", "label": "Base URL", "type": "input"},
				field{"name": "api_version", "label": "API Version", "type": "input"},
			)
		case *model.ModelConfigTypeChatQianfan.Code():
			fields = append(fields,
				field{"name": "access_key", "label": "Access Key", "type": "input", "required": true},
				field{"name": "secret_key", "label": "Secret Key", "type": "input", "required": true},
			)
		case *model.ModelConfigTypeChatQwen.Code():
			fields = append(fields,
				field{"name": "api_key", "label": "API Key", "type": "input", "required": true},
				field{"name": "base_url", "label": "Base URL", "type": "input"},
			)
		}
	case *model.ModelTypeVector.Code():
		switch configType {
		case *model.ModelConfigTypeVectorArk.Code():
			fields = append(fields,
				field{"name": "api_key", "label": "API Key", "type": "input", "required": true},
				field{"name": "api_type", "label": "API Type", "type": "input"},
			)
		case *model.ModelConfigTypeVectorOllama.Code():
			fields = append(fields, field{"name": "base_url", "label": "Base URL", "type": "input", "required": true, "default": "http://127.0.0.1:11434"})
		case *model.ModelConfigTypeVectorOpenAI.Code():
			fields = append(fields,
				field{"name": "api_key", "label": "API Key", "type": "input", "required": true},
				field{"name": "by_azure", "label": "使用 Azure", "type": "switch", "default": true},
				field{"name": "base_url", "label": "Base URL", "type": "input"},
				field{"name": "api_version", "label": "API Version", "type": "input"},
			)
		case *model.ModelConfigTypeVectorQianfan.Code():
			fields = append(fields,
				field{"name": "access_key", "label": "Access Key", "type": "input", "required": true},
				field{"name": "secret_key", "label": "Secret Key", "type": "input", "required": true},
			)
		case *model.ModelConfigTypeVectorTencentCloud.Code():
			fields = append(fields,
				field{"name": "secret_id", "label": "Secret ID", "type": "input", "required": true},
				field{"name": "secret_key", "label": "Secret Key", "type": "input", "required": true},
				field{"name": "region", "label": "Region", "type": "input", "required": true, "default": "ap-beijing"},
			)
		case *model.ModelConfigTypeVectorDashScope.Code():
			fields = append(fields, field{"name": "api_key", "label": "API Key", "type": "input", "required": true})
		}
	}

	return &dto.GetModelConfigFormFieldsRes{
		ModelType:  req.ModelType,
		ConfigType: req.ConfigType,
		Fields:     fields,
	}, nil
}
// Create inserts a new model configuration and then refreshes the runtime
// model built from it.
//
// Creation is rejected with an error when a record with the same ModelType
// already exists. If the insert succeeds but the refresh fails, the response
// carrying the new id is still returned alongside the refresh error.
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
	// Only proceed when no configuration of this model type exists yet.
	existing, err := dao.Model.Count(ctx, &dto.GetModelReq{ModelType: req.ModelType})
	if err != nil {
		return nil, err
	}
	if existing > 0 {
		return nil, gerror.New("模型配置已存在")
	}

	newID, err := dao.Model.Insert(ctx, req)
	if err != nil {
		return nil, err
	}

	res = &dto.CreateModelRes{Id: newID}
	return res, s.refresh(ctx, newID)
}
// Update modifies an existing model configuration and, when at least one row
// was changed, refreshes the runtime model built from it.
//
// The update is refused with an error while any task is in the running state.
func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) (err error) {
	// Refuse to change the configuration while tasks are running.
	running, err := dao.Task.Count(ctx, &dto.GetTaskReq{TaskStatus: task.TaskStatusRunning})
	if err != nil {
		return err
	}
	if !g.IsEmpty(running) {
		return gerror.New("任务正在执行中,模型配置暂时不可修改,请稍后再试")
	}

	var affected int64
	if affected, err = dao.Model.Update(ctx, req); err != nil {
		return err
	}
	if affected == 0 {
		// Nothing was changed, so there is nothing to refresh.
		return nil
	}
	return s.refresh(ctx, req.Id)
}
// refresh reloads the model configuration identified by id and passes it to
// the matching eino refresh function: RefreshTenantChatModel for chat models,
// RefreshTenantEmbedder for vector models. Other model types are ignored.
//
// An error is returned when the lookup fails, when no record is returned for
// id, when the record has no model type, or when the eino refresh fails.
func (s *modelService) refresh(ctx context.Context, id int64) (err error) {
	var modelDO *entity.Model
	modelDO, err = dao.Model.Get(ctx, &dto.GetModelReq{
		Id: id,
	})
	if err != nil {
		return err
	}
	// Guard against a missing record or type instead of panicking on the
	// dereferences below.
	if modelDO == nil {
		return gerror.Newf("模型配置不存在: id=%d", id)
	}
	if modelDO.ModelType == nil {
		return gerror.Newf("模型类型为空: id=%d", id)
	}

	switch *modelDO.ModelType {
	case *model.ModelTypeChat.Code():
		return eino.RefreshTenantChatModel(ctx, modelDO)
	case *model.ModelTypeVector.Code():
		return eino.RefreshTenantEmbedder(ctx, modelDO)
	}
	return nil
}
// Delete removes the model configuration(s) described by req through the
// model DAO. The DAO's first return value is discarded; only its error is
// propagated.
func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) (err error) {
	if _, delErr := dao.Model.Delete(ctx, req); delErr != nil {
		return delErr
	}
	return nil
}
// Get loads a single model configuration through the model DAO and converts
// it into its view object.
//
// A DAO error is returned as is; previously it was overwritten by the result
// of the conversion and therefore lost. A conversion error is returned with a
// nil result.
func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (res *dto.ModelVO, err error) {
	r, err := dao.Model.Get(ctx, req)
	if err != nil {
		return nil, err
	}
	// NOTE(review): behavior for a nil record depends on gconv.Struct — confirm
	// whether "not found" should be reported explicitly here.
	if err = gconv.Struct(r, &res); err != nil {
		return nil, err
	}
	return res, nil
}
// List queries model configurations through the model DAO and returns them,
// converted into the response's List field, together with the total count
// reported by the DAO.
//
// A DAO error yields a nil response. A conversion error is returned together
// with the response that already carries Total.
func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) {
	rows, count, err := dao.Model.List(ctx, req)
	if err != nil {
		return nil, err
	}

	res = new(dto.ListModelRes)
	res.Total = count

	// Convert the DAO rows into the response list shape.
	err = gconv.Struct(rows, &res.List)
	return res, err
}