feat: 支持多模型提供商 embedding

This commit is contained in:
2026-04-01 13:38:33 +08:00
parent bcbe6eba78
commit 2e4a0a89f1
9 changed files with 631 additions and 447 deletions

8
rag/eino/consts.go Normal file
View File

@@ -0,0 +1,8 @@
package eino
const (
providerArk = "ark"
providerOpenai = "openai"
providerQianfan = "qianfan"
providerDashscope = "dashscope"
)

View File

@@ -5,59 +5,60 @@ import (
"github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive"
"github.com/cloudwego/eino-ext/components/document/transformer/splitter/semantic"
"github.com/cloudwego/eino/components/document"
"github.com/cloudwego/eino/schema"
"github.com/gogf/gf/v2/frame/g"
)
// 全局只初始化一次
var (
splitter document.Transformer
)
// SemanticSplitDocument 语义分割文档
func SemanticSplitDocument(ctx context.Context, docs []*schema.Document) (res []*schema.Document, err error) {
if g.IsEmpty(splitter) {
// 默认分隔符(支持中英文)
separators := []string{"\n\n", "\n", "。", "", "", "", ".", "!", "?", ";"}
// 读取配置,使用合理的默认值
bufferSize := g.Cfg().MustGet(ctx, "eino.splitter.bufferSize").Int()
percentile := g.Cfg().MustGet(ctx, "eino.splitter.percentile").Float64()
batchSize := g.Cfg().MustGet(ctx, "eino.splitter.batchSize").Int()
if batchSize <= 0 {
batchSize = 10 // doubao-embedding-vision 限制每批最多 10 个
}
// 默认分隔符(支持中英文)
separators := []string{"\n\n", "\n", "。", "", "", "", ".", "!", "?", ";"}
// 读取配置,使用合理的默认值
bufferSize := g.Cfg().MustGet(ctx, "eino.splitter.bufferSize").Int()
minChunkSize := g.Cfg().MustGet(ctx, "eino.splitter.minChunkSize").Int()
percentile := g.Cfg().MustGet(ctx, "eino.splitter.percentile").Float64()
batchSize := g.Cfg().MustGet(ctx, "eino.splitter.batchSize").Int()
if batchSize <= 0 {
batchSize = 10 // doubao-embedding-vision 限制每批最多 10 个
}
// 使用批量包装器
batchEmbedder := NewBatchEmbedder(Embedder, batchSize)
// 使用批量包装器
var batchEmbedder *BatchEmbedder
provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String()
switch provider {
case providerArk:
batchEmbedder = NewBatchEmbedder(EmbedderArk, batchSize)
case providerOpenai:
batchEmbedder = NewBatchEmbedder(EmbedderOpenAI, batchSize)
case providerDashscope:
batchEmbedder = NewBatchEmbedder(EmbedderDashscope, batchSize)
}
splitter, err = semantic.NewSplitter(ctx, &semantic.Config{
Embedding: batchEmbedder,
BufferSize: bufferSize,
Percentile: percentile,
Separators: separators,
})
if err != nil {
return
}
splitter, err := semantic.NewSplitter(ctx, &semantic.Config{
Embedding: batchEmbedder,
BufferSize: bufferSize,
MinChunkSize: minChunkSize,
Percentile: percentile,
Separators: separators,
})
if err != nil {
return
}
return splitter.Transform(ctx, docs)
}
// RecursiveSplitDocument 递归分割文档
func RecursiveSplitDocument(ctx context.Context, docs []*schema.Document) (res []*schema.Document, err error) {
if g.IsEmpty(splitter) {
// 默认分隔符(支持中英文)
separators := []string{"\n\n", "\n", "。", "", "", "", ".", "!", "?", ";"}
splitter, err = recursive.NewSplitter(ctx, &recursive.Config{
ChunkSize: 1500,
OverlapSize: 300,
KeepType: recursive.KeepTypeNone,
Separators: separators,
})
if err != nil {
return
}
// 默认分隔符(支持中英文)
separators := []string{"\n\n", "\n", "。", "", "", "", ".", "!", "?", ";"}
splitter, err := recursive.NewSplitter(ctx, &recursive.Config{
ChunkSize: 512,
OverlapSize: 100,
KeepType: recursive.KeepTypeNone,
Separators: separators,
})
if err != nil {
return
}
return splitter.Transform(ctx, docs)
}

View File

@@ -2,45 +2,68 @@ package eino
import (
"context"
"fmt"
"github.com/cloudwego/eino-ext/components/embedding/ark"
"github.com/cloudwego/eino-ext/components/embedding/dashscope"
"github.com/cloudwego/eino-ext/components/embedding/openai"
"github.com/gogf/gf/v2/frame/g"
"github.com/golang/glog"
)
// 全局只初始化一次
var (
Embedder *dashscope.Embedder // 导出供其他模块使用
EmbedderArk *ark.Embedder
EmbedderDashscope *dashscope.Embedder
EmbedderOpenAI *openai.Embedder
)
// init程序启动时自动执行一次
func init() {
ctx := context.Background()
if !g.Cfg().MustGet(ctx, "eino.embedding").IsEmpty() {
var err error
cfg := &dashscope.EmbeddingConfig{
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String()
switch provider {
case providerArk:
cfg := &ark.EmbeddingConfig{
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
}
if apiType := g.Cfg().MustGet(ctx, "eino.embedding.apiType").String(); apiType != "" {
apiTypeVal := ark.APIType(apiType)
cfg.APIType = &apiTypeVal
}
EmbedderArk, err = ark.NewEmbedder(ctx, cfg)
case providerOpenai:
chatModelConfig := &openai.EmbeddingConfig{
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
}
EmbedderOpenAI, err = openai.NewEmbedder(ctx, chatModelConfig)
case providerDashscope:
cfg := &dashscope.EmbeddingConfig{
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
}
EmbedderDashscope, err = dashscope.NewEmbedder(ctx, cfg)
}
// 检查是否配置了 APIType支持 "text_api" 和 "multi_modal_api"
//if apiType := g.Cfg().MustGet(ctx, "eino.embedding.apiType").String(); apiType != "" {
// apiTypeVal := dashscope.APIType(apiType)
// cfg.APIType = &apiTypeVal
//}
Embedder, err = dashscope.NewEmbedder(ctx, cfg)
if err != nil {
glog.Fatalf("NewEmbedder of ark error: %v", err)
glog.Fatalf("NewEmbedder of %v error: %v", provider, err)
}
//embedding, err := embedder.EmbedStrings(ctx, []string{"hello world", "bye bye"})
//if err != nil {
// log.Printf("embedding error: %v\n", err)
// return
//}
//
//log.Printf("embedding: %v\n", embedding)
}
return
}
func EmbedStrings(ctx context.Context, texts []string) (embeddings [][]float64, err error) {
return Embedder.EmbedStrings(ctx, texts)
provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String()
switch provider {
case providerArk:
return EmbedderArk.EmbedStrings(ctx, texts)
case providerOpenai:
return EmbedderOpenAI.EmbedStrings(ctx, texts)
case providerDashscope:
return EmbedderDashscope.EmbedStrings(ctx, texts)
}
return nil, fmt.Errorf("unsupported provider: %v", provider)
}

114
rag/gse/utils.go Normal file
View File

@@ -0,0 +1,114 @@
package gse
import (
"context"
"sort"
"github.com/go-ego/gse"
"github.com/go-ego/gse/hmm/extracker"
"github.com/go-ego/gse/hmm/segment"
"github.com/gogf/gf/v2/os/glog"
)
var GseTool *gseTool
// 初始化函数:程序启动时执行一次
func init() {
var err error
GseTool, err = newGseTool()
if err != nil {
glog.Error(context.Background(), err)
}
}
// gseTool 关键词提取工具gse v1.0.2 标准)
type gseTool struct {
seg gse.Segmenter
tfidf *extracker.TagExtracter
tr *extracker.TextRanker
}
// newGseTool 初始化工具(内置词典 + 停用词)
func newGseTool() (tool *gseTool, err error) {
// 1. 初始化分词器
var seg gse.Segmenter
// 内置词典(无外部文件)
err = seg.LoadDictEmbed()
if err != nil {
return
}
// 内置停用词v1.0.2 标准)
err = seg.LoadStopEmbed()
if err != nil {
return
}
// 2. 初始化 TF-IDF 提取器
tfidf := &extracker.TagExtracter{}
tfidf.WithGse(seg)
err = tfidf.LoadIdf()
if err != nil {
return
}
// 3. 初始化 TextRank 提取器
tr := &extracker.TextRanker{}
tr.WithGse(seg)
tool = &gseTool{
seg: seg,
tfidf: tfidf,
tr: tr,
}
return
}
// Cut 分词(关键词提取唯一正确模式:精确模式 + HMM
func (k *gseTool) Cut(text string) []string {
return k.seg.Cut(text, true)
}
// Keyword 最终输出:关键词 + 权重
type Keyword struct {
Word string `json:"word"`
Score float64 `json:"score"`
}
func (k *gseTool) Extract(text string, topN int) []Keyword {
// 1. 提取 TF-IDF
tfTags := k.extractTFIDF(text, topN)
// 2. 提取 TextRank
trTags := k.extractTextRank(text, topN)
// 3. 合并成最终关键词(业务最常用)
scoreMap := make(map[string]float64)
for _, tag := range tfTags {
scoreMap[tag.Text] = tag.Weight
}
for _, tag := range trTags {
scoreMap[tag.Text] = tag.Weight
}
// 转成切片并排序(高分在前)
res := make([]Keyword, 0, len(scoreMap))
for word, score := range scoreMap {
res = append(res, Keyword{Word: word, Score: score})
}
sort.Slice(res, func(i, j int) bool {
return res[i].Score > res[j].Score
})
return res
}
// ExtractTFIDF TF-IDF 关键词带权重90% 业务:文章标签、搜索、关键词
func (k *gseTool) extractTFIDF(text string, topN int) segment.Segments {
return k.tfidf.ExtractTags(text, topN)
}
// ExtractTextRank TextRank 关键词(带权重)长文本、摘要、语义理解
func (k *gseTool) extractTextRank(text string, topN int) segment.Segments {
return k.tr.TextRank(text, topN)
}