Files
rag/service/model.go

300 lines
9.4 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"context"
"rag/common/eino"
"rag/consts/model"
"rag/consts/task"
"rag/dao"
"rag/model/dto"
"rag/model/entity"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// ModelService is the package-level instance through which the model
// configuration service methods are called.
var ModelService = &modelService{}

// modelService carries no fields; it only groups the model configuration
// service methods under one receiver type.
type modelService struct{}
// GetModelAllEnums returns the full model enum set: every model type option,
// each paired with the configuration type options reported for that type by
// model.GetAllModelConfigTypeEnums.
//
// ctx and req are currently unused. When a configuration option list cannot be
// converted, the conversion error is returned together with the
// already-allocated result, whose Options field has not been set.
func (s *modelService) GetModelAllEnums(ctx context.Context, req *dto.GetModelAllEnumsReq) (res *dto.GetModelEnumRes, err error) {
	_, _ = ctx, req

	res = new(dto.GetModelEnumRes)

	var merged []dto.ModelEnumOption
	for _, typeOpt := range model.GetAllModelTypeEnums().Options {
		// Turn the option key back into a model type value so its
		// configuration types can be looked up.
		typeCode := gconv.PtrString(typeOpt.Key)
		configEnums := model.GetAllModelConfigTypeEnums(model.ModelType(typeCode))

		// Convert the configuration options into the DTO element type.
		var configTypes []dto.ModelKeyValue
		if err = gconv.Structs(configEnums.Options, &configTypes); err != nil {
			return res, err
		}

		merged = append(merged, dto.ModelEnumOption{
			Key:         typeOpt.Key,
			Value:       typeOpt.Value,
			ConfigTypes: configTypes,
		})
	}

	res.Options = merged
	return res, nil
}
// GetModelConfigFormFields builds the list of form field descriptors for a
// model configuration form, based on the requested model type and
// configuration type.
//
// The result always starts with the fixed base fields (read-only modelType and
// configType, then modelName, modelDesc and model). Fields specific to the
// chat or vector configuration type are appended after them. A model type that
// is neither chat nor vector yields only the base fields, with the configType
// value set to "未知类型". An unrecognised configuration type under a known
// model type also yields only the base fields.
//
// ctx is currently unused. An error is returned when req or req.ModelType is
// nil, or when req.ConfigType is nil for a chat or vector model type; without
// these guards the dereferences below would panic.
func (s *modelService) GetModelConfigFormFields(ctx context.Context, req *dto.GetModelConfigFormFieldsReq) (*dto.GetModelConfigFormFieldsRes, error) {
	_ = ctx

	// ModelType is dereferenced unconditionally below.
	if req == nil || req.ModelType == nil {
		return nil, gerror.New("模型类型不能为空")
	}
	// ConfigType is only dereferenced for the two known model types, so only
	// those paths require it to be present.
	knownModelType := *req.ModelType == *model.ModelTypeVector.Code() ||
		*req.ModelType == *model.ModelTypeChat.Code()
	if knownModelType && req.ConfigType == nil {
		return nil, gerror.New("配置类型不能为空")
	}

	fields := make([]map[string]interface{}, 0)

	// ===================== Fixed base fields =====================
	// 1. Model type: fixed read-only field.
	fields = append(fields, map[string]interface{}{
		"name":     "modelType",
		"label":    "模型类型",
		"type":     "text",
		"disabled": true,
		"required": true,
		"value":    model.GetModelTypeDescByCode(req.ModelType),
	})

	// Resolve the display text of the configuration type; it stays at the
	// fallback value when the model type is neither vector nor chat.
	var configTypeValue = "未知类型"
	if *req.ModelType == *model.ModelTypeVector.Code() {
		configTypeValue = model.GetVectorDescByCode(req.ConfigType)
	} else if *req.ModelType == *model.ModelTypeChat.Code() {
		configTypeValue = model.GetChatDescByCode(req.ConfigType)
	}

	// 2. Configuration type: fixed read-only field.
	fields = append(fields, map[string]interface{}{
		"name":     "configType",
		"label":    "配置类型",
		"type":     "text",
		"disabled": true,
		"required": true,
		"value":    configTypeValue,
	})

	// 3. Basic information.
	fields = append(fields, []map[string]interface{}{
		{
			"name":        "modelName",
			"label":       "模型名称",
			"type":        "input",
			"required":    true,
			"placeholder": "例如DeepSeek 对话模型",
		},
		{
			"name":     "modelDesc",
			"label":    "模型描述",
			"type":     "textarea",
			"required": false,
		},
	}...)

	// 4. Generic model identifier field.
	// NOTE(review): this label repeats the "模型类型" label of the modelType
	// field although the placeholder asks for a model identifier — confirm
	// whether a different label was intended.
	fields = append(fields, map[string]interface{}{
		"name":        "model",
		"label":       "模型类型",
		"type":        "input",
		"required":    true,
		"placeholder": "例如deepseek-chat / text-embedding-3-small",
	})

	// ===================== Dynamic configuration fields =====================
	// Chosen by model type + configuration type.
	switch *req.ModelType {
	case *model.ModelTypeChat.Code():
		switch *req.ConfigType {
		case *model.ModelConfigTypeChatArk.Code(), *model.ModelConfigTypeChatArkBot.Code():
			// Ark and ArkBot take the same single field.
			fields = append(fields, map[string]interface{}{"name": "api_key", "label": "API Key", "type": "input", "required": true})
		case *model.ModelConfigTypeChatClaude.Code():
			fields = append(fields, []map[string]interface{}{
				{"name": "by_bedrock", "label": "使用 AWS Bedrock", "type": "switch", "default": true},
				{"name": "access_key", "label": "Access Key", "type": "input"},
				{"name": "secret_access_key", "label": "Secret Access Key", "type": "input"},
				{"name": "region", "label": "Region", "type": "input"},
				{"name": "api_key", "label": "API Key", "type": "input"},
				{"name": "base_url", "label": "Base URL", "type": "input"},
			}...)
		case *model.ModelConfigTypeChatDeepSeek.Code():
			fields = append(fields, []map[string]interface{}{
				{"name": "api_key", "label": "API Key", "type": "input", "required": true},
				{"name": "base_url", "label": "Base URL", "type": "input", "default": "https://api.deepseek.com"},
			}...)
		case *model.ModelConfigTypeChatOllama.Code():
			fields = append(fields, map[string]interface{}{"name": "base_url", "label": "Base URL", "type": "input", "required": true, "default": "http://127.0.0.1:11434"})
		case *model.ModelConfigTypeChatOpenAI.Code():
			fields = append(fields, []map[string]interface{}{
				{"name": "api_key", "label": "API Key", "type": "input", "required": true},
				{"name": "by_azure", "label": "使用 Azure", "type": "switch", "default": true},
				{"name": "base_url", "label": "Base URL", "type": "input"},
				{"name": "api_version", "label": "API Version", "type": "input"},
			}...)
		case *model.ModelConfigTypeChatQianfan.Code():
			fields = append(fields, []map[string]interface{}{
				{"name": "access_key", "label": "Access Key", "type": "input", "required": true},
				{"name": "secret_key", "label": "Secret Key", "type": "input", "required": true},
			}...)
		case *model.ModelConfigTypeChatQwen.Code():
			fields = append(fields, []map[string]interface{}{
				{"name": "api_key", "label": "API Key", "type": "input", "required": true},
				{"name": "base_url", "label": "Base URL", "type": "input"},
			}...)
		}
	case *model.ModelTypeVector.Code():
		switch *req.ConfigType {
		case *model.ModelConfigTypeVectorArk.Code():
			fields = append(fields, []map[string]interface{}{
				{"name": "api_key", "label": "API Key", "type": "input", "required": true},
				{"name": "api_type", "label": "API Type", "type": "input"},
			}...)
		case *model.ModelConfigTypeVectorOllama.Code():
			fields = append(fields, map[string]interface{}{"name": "base_url", "label": "Base URL", "type": "input", "required": true, "default": "http://127.0.0.1:11434"})
		case *model.ModelConfigTypeVectorOpenAI.Code():
			fields = append(fields, []map[string]interface{}{
				{"name": "api_key", "label": "API Key", "type": "input", "required": true},
				{"name": "by_azure", "label": "使用 Azure", "type": "switch", "default": true},
				{"name": "base_url", "label": "Base URL", "type": "input"},
				{"name": "api_version", "label": "API Version", "type": "input"},
			}...)
		case *model.ModelConfigTypeVectorQianfan.Code():
			fields = append(fields, []map[string]interface{}{
				{"name": "access_key", "label": "Access Key", "type": "input", "required": true},
				{"name": "secret_key", "label": "Secret Key", "type": "input", "required": true},
			}...)
		case *model.ModelConfigTypeVectorTencentCloud.Code():
			fields = append(fields, []map[string]interface{}{
				{"name": "secret_id", "label": "Secret ID", "type": "input", "required": true},
				{"name": "secret_key", "label": "Secret Key", "type": "input", "required": true},
				{"name": "region", "label": "Region", "type": "input", "required": true, "default": "ap-beijing"},
			}...)
		case *model.ModelConfigTypeVectorDashScope.Code():
			fields = append(fields, map[string]interface{}{"name": "api_key", "label": "API Key", "type": "input", "required": true})
		}
	}

	return &dto.GetModelConfigFormFieldsRes{
		ModelType:  req.ModelType,
		ConfigType: req.ConfigType,
		Fields:     fields,
	}, nil
}
// Create inserts a new model configuration and then calls s.refresh for the
// inserted id.
//
// It first counts existing models with the same ModelType and rejects the
// request with "模型配置已存在" when any exist. If the insert succeeds but the
// refresh fails, the result holding the new id is returned together with the
// refresh error.
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
	existing, err := dao.Model.Count(ctx, &dto.GetModelReq{ModelType: req.ModelType})
	switch {
	case err != nil:
		return nil, err
	case existing > 0:
		return nil, gerror.New("模型配置已存在")
	}

	newID, err := dao.Model.Insert(ctx, req)
	if err != nil {
		return nil, err
	}

	res = &dto.CreateModelRes{Id: newID}
	err = s.refresh(ctx, newID)
	return res, err
}
// Update modifies a model configuration and, when the update count is
// non-empty, calls s.refresh for req.Id.
//
// The update is refused with an error while the count of tasks in
// task.TaskStatusRunning is non-empty. An update that reports an empty count
// returns nil without refreshing.
func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) (err error) {
	running, err := dao.Task.Count(ctx, &dto.GetTaskReq{TaskStatus: task.TaskStatusRunning})
	if err != nil {
		return err
	}
	if !g.IsEmpty(running) {
		return gerror.New("任务正在执行中,模型配置暂时不可修改,请稍后再试")
	}

	affected, err := dao.Model.Update(ctx, req)
	if err != nil {
		return err
	}
	if g.IsEmpty(affected) {
		// Nothing was updated, so there is nothing to refresh.
		return nil
	}
	return s.refresh(ctx, req.Id)
}
// refresh loads the model configuration with the given id and passes it to
// the eino refresh function matching its model type:
// eino.RefreshTenantChatModel for chat models and eino.RefreshTenantEmbedder
// for vector models.
//
// It returns the lookup error, an error when no record is returned for id, or
// the error from the eino refresh call. A record whose model type is unset, or
// is neither chat nor vector, is left alone and nil is returned.
func (s *modelService) refresh(ctx context.Context, id int64) (err error) {
	var modelDO *entity.Model
	modelDO, err = dao.Model.Get(ctx, &dto.GetModelReq{Id: id})
	if err != nil {
		return err
	}
	// Guard against a missing record; dereferencing it below would panic.
	if modelDO == nil {
		return gerror.Newf("模型配置不存在: id=%d", id)
	}
	// An unset model type is treated like an unknown one: nothing to refresh.
	if modelDO.ModelType == nil {
		return nil
	}

	switch *modelDO.ModelType {
	case *model.ModelTypeChat.Code():
		return eino.RefreshTenantChatModel(ctx, modelDO)
	case *model.ModelTypeVector.Code():
		return eino.RefreshTenantEmbedder(ctx, modelDO)
	}
	return nil
}
// Delete removes model configuration data by passing req to dao.Model.Delete.
// The first value returned by the DAO is discarded; only the error is
// propagated.
func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) (err error) {
	if _, err = dao.Model.Delete(ctx, req); err != nil {
		return err
	}
	return nil
}
// Get loads a model configuration through dao.Model.Get and converts it into
// a dto.ModelVO with gconv.Struct.
//
// The DAO error is returned as-is rather than being overwritten by the
// conversion result. On any error the returned view object is nil.
func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (res *dto.ModelVO, err error) {
	record, err := dao.Model.Get(ctx, req)
	if err != nil {
		return nil, err
	}
	if err = gconv.Struct(record, &res); err != nil {
		return nil, err
	}
	return res, nil
}
// List queries model configurations through dao.Model.List and returns them
// with the total reported by the DAO.
//
// A DAO error yields a nil result. If converting the rows into res.List
// fails, the result (with Total already set) is returned together with the
// conversion error.
func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) {
	rows, count, err := dao.Model.List(ctx, req)
	if err != nil {
		return nil, err
	}

	res = new(dto.ListModelRes)
	res.Total = count
	err = gconv.Struct(rows, &res.List)
	return res, err
}