244 lines
7.6 KiB
Go
244 lines
7.6 KiB
Go
|
|
package eino
|
|||
|
|
|
|||
|
|
import (
	"context"
	"errors"
	"fmt"
	"sync"

	"rag/consts/model"
	"rag/dao"
	"rag/model/dto"
	"rag/model/entity"

	"gitea.com/red-future/common/jaeger"
	"gitea.com/red-future/common/utils"
	"github.com/cloudwego/eino-ext/components/model/ark"
	"github.com/cloudwego/eino-ext/components/model/arkbot"
	"github.com/cloudwego/eino-ext/components/model/claude"
	"github.com/cloudwego/eino-ext/components/model/deepseek"
	"github.com/cloudwego/eino-ext/components/model/ollama"
	"github.com/cloudwego/eino-ext/components/model/openai"
	"github.com/cloudwego/eino-ext/components/model/qianfan"
	"github.com/cloudwego/eino-ext/components/model/qwen"
	modelChat "github.com/cloudwego/eino/components/model"
	"github.com/gogf/gf/v2/frame/g"
	"github.com/gogf/gf/v2/util/gconv"
)
|
|||
|
|
|
|||
|
|
type ChatModelSet struct {
|
|||
|
|
Ark *ark.ChatModel
|
|||
|
|
ArkBot *arkbot.ChatModel
|
|||
|
|
Claude *claude.ChatModel
|
|||
|
|
DeepSeek *deepseek.ChatModel
|
|||
|
|
Ollama *ollama.ChatModel
|
|||
|
|
OpenAI *openai.ChatModel
|
|||
|
|
Qianfan *qianfan.ChatModel
|
|||
|
|
Qwen *qwen.ChatModel
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 全局租户容器:key=tenantId,value=该租户的对话模型
|
|||
|
|
var tenantChatModels = make(map[uint64]*ChatModelSet)
|
|||
|
|
|
|||
|
|
func init() {
|
|||
|
|
ctx := context.Background()
|
|||
|
|
ctx, span := jaeger.NewSpan(ctx, "InitAllChat")
|
|||
|
|
defer span.End()
|
|||
|
|
InitAllChat(ctx)
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ===================== 1. 服务启动时:初始化所有租户对话模型 =====================
|
|||
|
|
func InitAllChat(ctx context.Context) {
|
|||
|
|
list, err := dao.Model.GetNoTenantId(ctx, &dto.GetModelReq{
|
|||
|
|
ModelType: model.ModelTypeChat.Code(),
|
|||
|
|
})
|
|||
|
|
if err != nil {
|
|||
|
|
g.Log().Errorf(ctx, "获取所有租户对话模型失败: %v", err)
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
for _, l := range list {
|
|||
|
|
err = InitChat(ctx, l)
|
|||
|
|
if err != nil {
|
|||
|
|
g.Log().Errorf(ctx, "初始化租户[%v]的对话模型失败: %v", l.TenantId, err)
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func InitChat(ctx context.Context, modelDO *entity.Model) (err error) {
|
|||
|
|
set := &ChatModelSet{}
|
|||
|
|
switch *modelDO.ConfigType {
|
|||
|
|
case *model.ModelConfigTypeChatArk.Code():
|
|||
|
|
var cfg entity.ChatModelConfigArk
|
|||
|
|
if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
|
|||
|
|
return fmt.Errorf("解析Ark配置失败: %v", err)
|
|||
|
|
}
|
|||
|
|
set.Ark, err = ark.NewChatModel(ctx, &ark.ChatModelConfig{
|
|||
|
|
APIKey: cfg.APIKey,
|
|||
|
|
Model: cfg.Model,
|
|||
|
|
Temperature: gconv.PtrFloat32(0.7),
|
|||
|
|
MaxTokens: gconv.PtrInt(1024),
|
|||
|
|
TopP: gconv.PtrFloat32(1.0),
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
case *model.ModelConfigTypeChatArkBot.Code():
|
|||
|
|
var cfg entity.ChatModelConfigArkBot
|
|||
|
|
if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
|
|||
|
|
return fmt.Errorf("解析ArkBot配置失败: %v", err)
|
|||
|
|
}
|
|||
|
|
set.ArkBot, err = arkbot.NewChatModel(ctx, &arkbot.Config{
|
|||
|
|
APIKey: cfg.APIKey,
|
|||
|
|
Model: cfg.Model,
|
|||
|
|
Temperature: gconv.PtrFloat32(0.7),
|
|||
|
|
MaxTokens: gconv.PtrInt(1024),
|
|||
|
|
TopP: gconv.PtrFloat32(1.0),
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
case *model.ModelConfigTypeChatClaude.Code():
|
|||
|
|
var cfg entity.ChatModelConfigClaude
|
|||
|
|
if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
|
|||
|
|
return fmt.Errorf("解析Claude配置失败: %v", err)
|
|||
|
|
}
|
|||
|
|
claudeCfg := claude.Config{
|
|||
|
|
APIKey: cfg.APIKey,
|
|||
|
|
BaseURL: gconv.PtrString(cfg.BaseURL),
|
|||
|
|
Model: cfg.Model,
|
|||
|
|
Temperature: gconv.PtrFloat32(0.7),
|
|||
|
|
MaxTokens: gconv.Int(1024),
|
|||
|
|
TopP: gconv.PtrFloat32(1.0),
|
|||
|
|
ByBedrock: cfg.ByBedrock,
|
|||
|
|
AccessKey: cfg.AccessKey,
|
|||
|
|
SecretAccessKey: cfg.SecretAccessKey,
|
|||
|
|
Region: cfg.Region,
|
|||
|
|
}
|
|||
|
|
set.Claude, err = claude.NewChatModel(ctx, &claudeCfg)
|
|||
|
|
|
|||
|
|
case *model.ModelConfigTypeChatDeepSeek.Code():
|
|||
|
|
var cfg entity.ChatModelConfigDeepSeek
|
|||
|
|
if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
|
|||
|
|
return fmt.Errorf("解析DeepSeek配置失败: %v", err)
|
|||
|
|
}
|
|||
|
|
set.DeepSeek, err = deepseek.NewChatModel(ctx, &deepseek.ChatModelConfig{
|
|||
|
|
APIKey: cfg.APIKey,
|
|||
|
|
Model: cfg.Model,
|
|||
|
|
BaseURL: cfg.BaseURL,
|
|||
|
|
Temperature: gconv.Float32(0.7),
|
|||
|
|
MaxTokens: gconv.Int(1024),
|
|||
|
|
TopP: gconv.Float32(1.0),
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
case *model.ModelConfigTypeChatOllama.Code():
|
|||
|
|
var cfg entity.ChatModelConfigOllama
|
|||
|
|
if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
|
|||
|
|
return fmt.Errorf("解析Ollama配置失败: %v", err)
|
|||
|
|
}
|
|||
|
|
set.Ollama, err = ollama.NewChatModel(ctx, &ollama.ChatModelConfig{
|
|||
|
|
BaseURL: cfg.BaseURL,
|
|||
|
|
Model: cfg.Model,
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
case *model.ModelConfigTypeChatOpenAI.Code():
|
|||
|
|
var cfg entity.ChatModelConfigOpenAI
|
|||
|
|
if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
|
|||
|
|
return fmt.Errorf("解析OpenAI配置失败: %v", err)
|
|||
|
|
}
|
|||
|
|
openAiCfg := openai.ChatModelConfig{
|
|||
|
|
APIKey: cfg.APIKey,
|
|||
|
|
Model: cfg.Model,
|
|||
|
|
ByAzure: cfg.ByAzure,
|
|||
|
|
BaseURL: cfg.BaseURL,
|
|||
|
|
APIVersion: cfg.APIVersion,
|
|||
|
|
Temperature: gconv.PtrFloat32(0.7),
|
|||
|
|
MaxCompletionTokens: gconv.PtrInt(1024),
|
|||
|
|
TopP: gconv.PtrFloat32(1.0),
|
|||
|
|
}
|
|||
|
|
set.OpenAI, err = openai.NewChatModel(ctx, &openAiCfg)
|
|||
|
|
|
|||
|
|
case *model.ModelConfigTypeChatQianfan.Code():
|
|||
|
|
var cfg entity.ChatModelConfigQianfan
|
|||
|
|
if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
|
|||
|
|
return fmt.Errorf("解析千帆配置失败: %v", err)
|
|||
|
|
}
|
|||
|
|
qcfg := qianfan.GetQianfanSingletonConfig()
|
|||
|
|
qcfg.AccessKey = cfg.AccessKey
|
|||
|
|
qcfg.SecretKey = cfg.SecretKey
|
|||
|
|
set.Qianfan, err = qianfan.NewChatModel(ctx, &qianfan.ChatModelConfig{
|
|||
|
|
Model: cfg.Model,
|
|||
|
|
Temperature: gconv.PtrFloat32(0.7),
|
|||
|
|
MaxCompletionTokens: gconv.PtrInt(1024),
|
|||
|
|
TopP: gconv.PtrFloat32(1.0),
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
case *model.ModelConfigTypeChatQwen.Code():
|
|||
|
|
var cfg entity.ChatModelConfigQwen
|
|||
|
|
if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
|
|||
|
|
return fmt.Errorf("解析Qwen配置失败: %v", err)
|
|||
|
|
}
|
|||
|
|
set.Qwen, err = qwen.NewChatModel(ctx, &qwen.ChatModelConfig{
|
|||
|
|
APIKey: cfg.APIKey,
|
|||
|
|
Model: cfg.Model,
|
|||
|
|
BaseURL: cfg.BaseURL,
|
|||
|
|
Temperature: gconv.PtrFloat32(0.7),
|
|||
|
|
MaxTokens: gconv.PtrInt(1024),
|
|||
|
|
TopP: gconv.PtrFloat32(1.0),
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
default:
|
|||
|
|
return fmt.Errorf("不支持的对话模型类型: %v", *modelDO.ConfigType)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if err != nil {
|
|||
|
|
return fmt.Errorf("初始化对话模型失败: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 无锁存入租户 map
|
|||
|
|
tenantChatModels[modelDO.TenantId] = set
|
|||
|
|
g.Log().Infof(ctx, "租户[%v]对话模型[%v]初始化成功", modelDO.TenantId, *modelDO.ConfigType)
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func GetTenantChatModel(tenantId uint64) (*ChatModelSet, error) {
|
|||
|
|
set := tenantChatModels[tenantId]
|
|||
|
|
if set == nil {
|
|||
|
|
return nil, fmt.Errorf("租户[%v]对话模型未初始化", tenantId)
|
|||
|
|
}
|
|||
|
|
return set, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func GetTenantChatModelByType(ctx context.Context, configType model.ModelConfigType) (modelChat.BaseChatModel, error) {
|
|||
|
|
userInfo, err := utils.GetUserInfo(ctx)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
set, err := GetTenantChatModel(userInfo.TenantId)
|
|||
|
|
if set == nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
switch *configType {
|
|||
|
|
case *model.ModelConfigTypeChatArk.Code():
|
|||
|
|
return set.Ark, nil
|
|||
|
|
case *model.ModelConfigTypeChatArkBot.Code():
|
|||
|
|
return set.ArkBot, nil
|
|||
|
|
case *model.ModelConfigTypeChatClaude.Code():
|
|||
|
|
return set.Claude, nil
|
|||
|
|
case *model.ModelConfigTypeChatDeepSeek.Code():
|
|||
|
|
return set.DeepSeek, nil
|
|||
|
|
case *model.ModelConfigTypeChatOllama.Code():
|
|||
|
|
return set.Ollama, nil
|
|||
|
|
case *model.ModelConfigTypeChatOpenAI.Code():
|
|||
|
|
return set.OpenAI, nil
|
|||
|
|
case *model.ModelConfigTypeChatQianfan.Code():
|
|||
|
|
return set.Qianfan, nil
|
|||
|
|
case *model.ModelConfigTypeChatQwen.Code():
|
|||
|
|
return set.Qwen, nil
|
|||
|
|
default:
|
|||
|
|
return nil, fmt.Errorf("不支持的对话模型类型: %v", configType)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func RefreshTenantChatModel(ctx context.Context, modelDO *entity.Model) error {
|
|||
|
|
delete(tenantChatModels, modelDO.TenantId)
|
|||
|
|
return InitChat(ctx, modelDO)
|
|||
|
|
}
|