feat: 支持多模型提供商 embedding
This commit is contained in:
8
rag/eino/consts.go
Normal file
8
rag/eino/consts.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package eino
|
||||
|
||||
const (
|
||||
providerArk = "ark"
|
||||
providerOpenai = "openai"
|
||||
providerQianfan = "qianfan"
|
||||
providerDashscope = "dashscope"
|
||||
)
|
||||
@@ -5,59 +5,60 @@ import (
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive"
|
||||
"github.com/cloudwego/eino-ext/components/document/transformer/splitter/semantic"
|
||||
"github.com/cloudwego/eino/components/document"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
// 全局只初始化一次
|
||||
var (
|
||||
splitter document.Transformer
|
||||
)
|
||||
|
||||
// SemanticSplitDocument 语义分割文档
|
||||
func SemanticSplitDocument(ctx context.Context, docs []*schema.Document) (res []*schema.Document, err error) {
|
||||
if g.IsEmpty(splitter) {
|
||||
// 默认分隔符(支持中英文)
|
||||
separators := []string{"\n\n", "\n", "。", "!", "?", ";", ".", "!", "?", ";"}
|
||||
// 读取配置,使用合理的默认值
|
||||
bufferSize := g.Cfg().MustGet(ctx, "eino.splitter.bufferSize").Int()
|
||||
percentile := g.Cfg().MustGet(ctx, "eino.splitter.percentile").Float64()
|
||||
batchSize := g.Cfg().MustGet(ctx, "eino.splitter.batchSize").Int()
|
||||
if batchSize <= 0 {
|
||||
batchSize = 10 // doubao-embedding-vision 限制每批最多 10 个
|
||||
}
|
||||
// 默认分隔符(支持中英文)
|
||||
separators := []string{"\n\n", "\n", "。", "!", "?", ";", ".", "!", "?", ";"}
|
||||
// 读取配置,使用合理的默认值
|
||||
bufferSize := g.Cfg().MustGet(ctx, "eino.splitter.bufferSize").Int()
|
||||
minChunkSize := g.Cfg().MustGet(ctx, "eino.splitter.minChunkSize").Int()
|
||||
percentile := g.Cfg().MustGet(ctx, "eino.splitter.percentile").Float64()
|
||||
batchSize := g.Cfg().MustGet(ctx, "eino.splitter.batchSize").Int()
|
||||
if batchSize <= 0 {
|
||||
batchSize = 10 // doubao-embedding-vision 限制每批最多 10 个
|
||||
}
|
||||
|
||||
// 使用批量包装器
|
||||
batchEmbedder := NewBatchEmbedder(Embedder, batchSize)
|
||||
// 使用批量包装器
|
||||
var batchEmbedder *BatchEmbedder
|
||||
provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String()
|
||||
switch provider {
|
||||
case providerArk:
|
||||
batchEmbedder = NewBatchEmbedder(EmbedderArk, batchSize)
|
||||
case providerOpenai:
|
||||
batchEmbedder = NewBatchEmbedder(EmbedderOpenAI, batchSize)
|
||||
case providerDashscope:
|
||||
batchEmbedder = NewBatchEmbedder(EmbedderDashscope, batchSize)
|
||||
}
|
||||
|
||||
splitter, err = semantic.NewSplitter(ctx, &semantic.Config{
|
||||
Embedding: batchEmbedder,
|
||||
BufferSize: bufferSize,
|
||||
Percentile: percentile,
|
||||
Separators: separators,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
splitter, err := semantic.NewSplitter(ctx, &semantic.Config{
|
||||
Embedding: batchEmbedder,
|
||||
BufferSize: bufferSize,
|
||||
MinChunkSize: minChunkSize,
|
||||
Percentile: percentile,
|
||||
Separators: separators,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return splitter.Transform(ctx, docs)
|
||||
}
|
||||
|
||||
// RecursiveSplitDocument 递归分割文档
|
||||
func RecursiveSplitDocument(ctx context.Context, docs []*schema.Document) (res []*schema.Document, err error) {
|
||||
if g.IsEmpty(splitter) {
|
||||
// 默认分隔符(支持中英文)
|
||||
separators := []string{"\n\n", "\n", "。", "!", "?", ";", ".", "!", "?", ";"}
|
||||
splitter, err = recursive.NewSplitter(ctx, &recursive.Config{
|
||||
ChunkSize: 1500,
|
||||
OverlapSize: 300,
|
||||
KeepType: recursive.KeepTypeNone,
|
||||
Separators: separators,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// 默认分隔符(支持中英文)
|
||||
separators := []string{"\n\n", "\n", "。", "!", "?", ";", ".", "!", "?", ";"}
|
||||
splitter, err := recursive.NewSplitter(ctx, &recursive.Config{
|
||||
ChunkSize: 512,
|
||||
OverlapSize: 100,
|
||||
KeepType: recursive.KeepTypeNone,
|
||||
Separators: separators,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return splitter.Transform(ctx, docs)
|
||||
}
|
||||
|
||||
@@ -2,45 +2,68 @@ package eino
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/embedding/ark"
|
||||
"github.com/cloudwego/eino-ext/components/embedding/dashscope"
|
||||
"github.com/cloudwego/eino-ext/components/embedding/openai"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/golang/glog"
|
||||
)
|
||||
|
||||
// 全局只初始化一次
|
||||
var (
|
||||
Embedder *dashscope.Embedder // 导出供其他模块使用
|
||||
EmbedderArk *ark.Embedder
|
||||
EmbedderDashscope *dashscope.Embedder
|
||||
EmbedderOpenAI *openai.Embedder
|
||||
)
|
||||
|
||||
// init:程序启动时自动执行一次
|
||||
func init() {
|
||||
ctx := context.Background()
|
||||
if !g.Cfg().MustGet(ctx, "eino.embedding").IsEmpty() {
|
||||
var err error
|
||||
cfg := &dashscope.EmbeddingConfig{
|
||||
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
|
||||
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
|
||||
provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String()
|
||||
switch provider {
|
||||
case providerArk:
|
||||
cfg := &ark.EmbeddingConfig{
|
||||
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
|
||||
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
|
||||
}
|
||||
if apiType := g.Cfg().MustGet(ctx, "eino.embedding.apiType").String(); apiType != "" {
|
||||
apiTypeVal := ark.APIType(apiType)
|
||||
cfg.APIType = &apiTypeVal
|
||||
}
|
||||
EmbedderArk, err = ark.NewEmbedder(ctx, cfg)
|
||||
case providerOpenai:
|
||||
chatModelConfig := &openai.EmbeddingConfig{
|
||||
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
|
||||
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
|
||||
}
|
||||
EmbedderOpenAI, err = openai.NewEmbedder(ctx, chatModelConfig)
|
||||
case providerDashscope:
|
||||
cfg := &dashscope.EmbeddingConfig{
|
||||
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
|
||||
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
|
||||
}
|
||||
EmbedderDashscope, err = dashscope.NewEmbedder(ctx, cfg)
|
||||
}
|
||||
// 检查是否配置了 APIType,支持 "text_api" 和 "multi_modal_api"
|
||||
//if apiType := g.Cfg().MustGet(ctx, "eino.embedding.apiType").String(); apiType != "" {
|
||||
// apiTypeVal := dashscope.APIType(apiType)
|
||||
// cfg.APIType = &apiTypeVal
|
||||
//}
|
||||
Embedder, err = dashscope.NewEmbedder(ctx, cfg)
|
||||
if err != nil {
|
||||
glog.Fatalf("NewEmbedder of ark error: %v", err)
|
||||
glog.Fatalf("NewEmbedder of %v error: %v", provider, err)
|
||||
}
|
||||
//embedding, err := embedder.EmbedStrings(ctx, []string{"hello world", "bye bye"})
|
||||
//if err != nil {
|
||||
// log.Printf("embedding error: %v\n", err)
|
||||
// return
|
||||
//}
|
||||
//
|
||||
//log.Printf("embedding: %v\n", embedding)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func EmbedStrings(ctx context.Context, texts []string) (embeddings [][]float64, err error) {
|
||||
return Embedder.EmbedStrings(ctx, texts)
|
||||
provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String()
|
||||
switch provider {
|
||||
case providerArk:
|
||||
return EmbedderArk.EmbedStrings(ctx, texts)
|
||||
case providerOpenai:
|
||||
return EmbedderOpenAI.EmbedStrings(ctx, texts)
|
||||
case providerDashscope:
|
||||
return EmbedderDashscope.EmbedStrings(ctx, texts)
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported provider: %v", provider)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user