300 lines
9.4 KiB
Go
300 lines
9.4 KiB
Go
package service
|
||
|
||
import (
	"context"
	"errors"
	"fmt"

	"github.com/gogf/gf/v2/errors/gerror"
	"github.com/gogf/gf/v2/frame/g"
	"github.com/gogf/gf/v2/util/gconv"

	"rag/common/eino"
	"rag/consts/model"
	"rag/consts/task"
	"rag/dao"
	"rag/model/dto"
	"rag/model/entity"
)
|
||
|
||
// modelService groups the model-configuration operations: enum listing,
// form-field generation, CRUD, and refreshing the runtime eino clients.
// It is stateless; all methods work off their arguments and the DAO layer.
type modelService struct{}

// ModelService is the package-level instance through which callers reach
// the model configuration service.
var ModelService = &modelService{}
|
||
|
||
// GetModelAllEnums 获取模型全量枚举(模型类型 + 配置类型 合并)
|
||
func (s *modelService) GetModelAllEnums(ctx context.Context, req *dto.GetModelAllEnumsReq) (res *dto.GetModelEnumRes, err error) {
|
||
_, _ = ctx, req
|
||
res = new(dto.GetModelEnumRes)
|
||
|
||
// 获取所有模型类型
|
||
modelTypeRes := model.GetAllModelTypeEnums()
|
||
|
||
var options []dto.ModelEnumOption
|
||
for _, mt := range modelTypeRes.Options {
|
||
// 构造 modelType
|
||
modelTypeStr := gconv.String(mt.Key)
|
||
modelType := model.ModelType(gconv.PtrString(modelTypeStr))
|
||
|
||
// 获取对应配置类型
|
||
configRes := model.GetAllModelConfigTypeEnums(modelType)
|
||
|
||
// 把 configRes.Options 转成目标类型
|
||
var configList []dto.ModelKeyValue
|
||
err = gconv.Structs(configRes.Options, &configList)
|
||
if err != nil {
|
||
return
|
||
}
|
||
options = append(options, dto.ModelEnumOption{
|
||
Key: mt.Key,
|
||
Value: mt.Value,
|
||
ConfigTypes: configList,
|
||
})
|
||
}
|
||
|
||
res.Options = options
|
||
return
|
||
}
|
||
|
||
func (s *modelService) GetModelConfigFormFields(ctx context.Context, req *dto.GetModelConfigFormFieldsReq) (*dto.GetModelConfigFormFieldsRes, error) {
|
||
_ = ctx
|
||
|
||
fields := make([]map[string]interface{}, 0)
|
||
|
||
// ===================== 固定基础字段(CreateModelReq 前4个)=====================
|
||
// 1. 模型类型:固定只读字段
|
||
fields = append(fields, map[string]interface{}{
|
||
"name": "modelType",
|
||
"label": "模型类型",
|
||
"type": "text",
|
||
"disabled": true,
|
||
"required": true,
|
||
"value": model.GetModelTypeDescByCode(req.ModelType),
|
||
})
|
||
|
||
var configTypeValue = "未知类型"
|
||
if *req.ModelType == *model.ModelTypeVector.Code() {
|
||
configTypeValue = model.GetVectorDescByCode(req.ConfigType)
|
||
} else if *req.ModelType == *model.ModelTypeChat.Code() {
|
||
configTypeValue = model.GetChatDescByCode(req.ConfigType)
|
||
}
|
||
|
||
// 2. 配置类型:固定只读字段
|
||
fields = append(fields, map[string]interface{}{
|
||
"name": "configType",
|
||
"label": "配置类型",
|
||
"type": "text",
|
||
"disabled": true,
|
||
"required": true,
|
||
"value": configTypeValue,
|
||
})
|
||
|
||
// 3. 基础信息
|
||
fields = append(fields, []map[string]interface{}{
|
||
{
|
||
"name": "modelName",
|
||
"label": "模型名称",
|
||
"type": "input",
|
||
"required": true,
|
||
"placeholder": "例如:DeepSeek 对话模型",
|
||
},
|
||
{
|
||
"name": "modelDesc",
|
||
"label": "模型描述",
|
||
"type": "textarea",
|
||
"required": false,
|
||
},
|
||
}...)
|
||
|
||
// 4. 通用模型名称字段
|
||
fields = append(fields, map[string]interface{}{
|
||
"name": "model",
|
||
"label": "模型类型",
|
||
"type": "input",
|
||
"required": true,
|
||
"placeholder": "例如:deepseek-chat / text-embedding-3-small",
|
||
})
|
||
|
||
// ===================== 动态配置内容 ConfigContent =====================
|
||
|
||
// 根据模型类型 + 配置类型生成动态字段
|
||
switch *req.ModelType {
|
||
case *model.ModelTypeChat.Code():
|
||
switch *req.ConfigType {
|
||
case *model.ModelConfigTypeChatArk.Code():
|
||
fields = append(fields, map[string]interface{}{"name": "api_key", "label": "API Key", "type": "input", "required": true})
|
||
|
||
case *model.ModelConfigTypeChatArkBot.Code():
|
||
fields = append(fields, map[string]interface{}{"name": "api_key", "label": "API Key", "type": "input", "required": true})
|
||
|
||
case *model.ModelConfigTypeChatClaude.Code():
|
||
fields = append(fields, []map[string]interface{}{
|
||
{"name": "by_bedrock", "label": "使用 AWS Bedrock", "type": "switch", "default": true},
|
||
{"name": "access_key", "label": "Access Key", "type": "input"},
|
||
{"name": "secret_access_key", "label": "Secret Access Key", "type": "input"},
|
||
{"name": "region", "label": "Region", "type": "input"},
|
||
{"name": "api_key", "label": "API Key", "type": "input"},
|
||
{"name": "base_url", "label": "Base URL", "type": "input"},
|
||
}...)
|
||
|
||
case *model.ModelConfigTypeChatDeepSeek.Code():
|
||
fields = append(fields, []map[string]interface{}{
|
||
{"name": "api_key", "label": "API Key", "type": "input", "required": true},
|
||
{"name": "base_url", "label": "Base URL", "type": "input", "default": "https://api.deepseek.com"},
|
||
}...)
|
||
|
||
case *model.ModelConfigTypeChatOllama.Code():
|
||
fields = append(fields, map[string]interface{}{"name": "base_url", "label": "Base URL", "type": "input", "required": true, "default": "http://127.0.0.1:11434"})
|
||
|
||
case *model.ModelConfigTypeChatOpenAI.Code():
|
||
fields = append(fields, []map[string]interface{}{
|
||
{"name": "api_key", "label": "API Key", "type": "input", "required": true},
|
||
{"name": "by_azure", "label": "使用 Azure", "type": "switch", "default": true},
|
||
{"name": "base_url", "label": "Base URL", "type": "input"},
|
||
{"name": "api_version", "label": "API Version", "type": "input"},
|
||
}...)
|
||
|
||
case *model.ModelConfigTypeChatQianfan.Code():
|
||
fields = append(fields, []map[string]interface{}{
|
||
{"name": "access_key", "label": "Access Key", "type": "input", "required": true},
|
||
{"name": "secret_key", "label": "Secret Key", "type": "input", "required": true},
|
||
}...)
|
||
|
||
case *model.ModelConfigTypeChatQwen.Code():
|
||
fields = append(fields, []map[string]interface{}{
|
||
{"name": "api_key", "label": "API Key", "type": "input", "required": true},
|
||
{"name": "base_url", "label": "Base URL", "type": "input"},
|
||
}...)
|
||
}
|
||
|
||
case *model.ModelTypeVector.Code():
|
||
switch *req.ConfigType {
|
||
case *model.ModelConfigTypeVectorArk.Code():
|
||
fields = append(fields, []map[string]interface{}{
|
||
{"name": "api_key", "label": "API Key", "type": "input", "required": true},
|
||
{"name": "api_type", "label": "API Type", "type": "input"},
|
||
}...)
|
||
|
||
case *model.ModelConfigTypeVectorOllama.Code():
|
||
fields = append(fields, map[string]interface{}{"name": "base_url", "label": "Base URL", "type": "input", "required": true, "default": "http://127.0.0.1:11434"})
|
||
|
||
case *model.ModelConfigTypeVectorOpenAI.Code():
|
||
fields = append(fields, []map[string]interface{}{
|
||
{"name": "api_key", "label": "API Key", "type": "input", "required": true},
|
||
{"name": "by_azure", "label": "使用 Azure", "type": "switch", "default": true},
|
||
{"name": "base_url", "label": "Base URL", "type": "input"},
|
||
{"name": "api_version", "label": "API Version", "type": "input"},
|
||
}...)
|
||
|
||
case *model.ModelConfigTypeVectorQianfan.Code():
|
||
fields = append(fields, []map[string]interface{}{
|
||
{"name": "access_key", "label": "Access Key", "type": "input", "required": true},
|
||
{"name": "secret_key", "label": "Secret Key", "type": "input", "required": true},
|
||
}...)
|
||
|
||
case *model.ModelConfigTypeVectorTencentCloud.Code():
|
||
fields = append(fields, []map[string]interface{}{
|
||
{"name": "secret_id", "label": "Secret ID", "type": "input", "required": true},
|
||
{"name": "secret_key", "label": "Secret Key", "type": "input", "required": true},
|
||
{"name": "region", "label": "Region", "type": "input", "required": true, "default": "ap-beijing"},
|
||
}...)
|
||
case *model.ModelConfigTypeVectorDashScope.Code():
|
||
fields = append(fields, map[string]interface{}{"name": "api_key", "label": "API Key", "type": "input", "required": true})
|
||
}
|
||
}
|
||
|
||
return &dto.GetModelConfigFormFieldsRes{
|
||
ModelType: req.ModelType,
|
||
ConfigType: req.ConfigType,
|
||
Fields: fields,
|
||
}, nil
|
||
}
|
||
|
||
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
|
||
count, err := dao.Model.Count(ctx, &dto.GetModelReq{
|
||
ModelType: req.ModelType,
|
||
})
|
||
if err != nil {
|
||
return
|
||
}
|
||
if count > 0 {
|
||
err = gerror.New("模型配置已存在")
|
||
return
|
||
}
|
||
var id int64
|
||
id, err = dao.Model.Insert(ctx, req)
|
||
if err != nil {
|
||
return
|
||
}
|
||
res = &dto.CreateModelRes{Id: id}
|
||
err = s.refresh(ctx, id)
|
||
return
|
||
}
|
||
|
||
func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) (err error) {
|
||
count, err := dao.Task.Count(ctx, &dto.GetTaskReq{
|
||
TaskStatus: task.TaskStatusRunning,
|
||
})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if !g.IsEmpty(count) {
|
||
err = gerror.New("任务正在执行中,模型配置暂时不可修改,请稍后再试")
|
||
return
|
||
}
|
||
var updateCount int64
|
||
updateCount, err = dao.Model.Update(ctx, req)
|
||
if err != nil {
|
||
return
|
||
}
|
||
if !g.IsEmpty(updateCount) {
|
||
err = s.refresh(ctx, req.Id)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
}
|
||
return
|
||
}
|
||
|
||
func (s *modelService) refresh(ctx context.Context, id int64) (err error) {
|
||
var modelDO *entity.Model
|
||
modelDO, err = dao.Model.Get(ctx, &dto.GetModelReq{
|
||
Id: id,
|
||
})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if *modelDO.ModelType == *model.ModelTypeChat.Code() {
|
||
if err = eino.RefreshTenantChatModel(ctx, modelDO); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
if *modelDO.ModelType == *model.ModelTypeVector.Code() {
|
||
if err = eino.RefreshTenantEmbedder(ctx, modelDO); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) (err error) {
|
||
_, err = dao.Model.Delete(ctx, req)
|
||
return
|
||
}
|
||
|
||
func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (res *dto.ModelVO, err error) {
|
||
r, err := dao.Model.Get(ctx, req)
|
||
err = gconv.Struct(r, &res)
|
||
return
|
||
}
|
||
|
||
func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) {
|
||
list, total, err := dao.Model.List(ctx, req)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
res = &dto.ListModelRes{
|
||
Total: total,
|
||
}
|
||
err = gconv.Struct(list, &res.List)
|
||
return
|
||
}
|