70 lines
2.1 KiB
Go
70 lines
2.1 KiB
Go
package eino
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
"github.com/cloudwego/eino-ext/components/embedding/ark"
|
|
"github.com/cloudwego/eino-ext/components/embedding/dashscope"
|
|
"github.com/cloudwego/eino-ext/components/embedding/openai"
|
|
"github.com/gogf/gf/v2/frame/g"
|
|
"github.com/golang/glog"
|
|
)
|
|
|
|
// 全局只初始化一次
|
|
var (
|
|
EmbedderArk *ark.Embedder
|
|
EmbedderDashscope *dashscope.Embedder
|
|
EmbedderOpenAI *openai.Embedder
|
|
)
|
|
|
|
func init() {
|
|
ctx := context.Background()
|
|
if !g.Cfg().MustGet(ctx, "eino.embedding").IsEmpty() {
|
|
var err error
|
|
provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String()
|
|
switch provider {
|
|
case providerArk:
|
|
cfg := &ark.EmbeddingConfig{
|
|
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
|
|
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
|
|
}
|
|
if apiType := g.Cfg().MustGet(ctx, "eino.embedding.apiType").String(); apiType != "" {
|
|
apiTypeVal := ark.APIType(apiType)
|
|
cfg.APIType = &apiTypeVal
|
|
}
|
|
EmbedderArk, err = ark.NewEmbedder(ctx, cfg)
|
|
case providerOpenai:
|
|
chatModelConfig := &openai.EmbeddingConfig{
|
|
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
|
|
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
|
|
}
|
|
EmbedderOpenAI, err = openai.NewEmbedder(ctx, chatModelConfig)
|
|
case providerDashscope:
|
|
cfg := &dashscope.EmbeddingConfig{
|
|
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
|
|
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
|
|
}
|
|
EmbedderDashscope, err = dashscope.NewEmbedder(ctx, cfg)
|
|
}
|
|
if err != nil {
|
|
glog.Fatalf("NewEmbedder of %v error: %v", provider, err)
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func EmbedStrings(ctx context.Context, texts []string) (embeddings [][]float64, err error) {
|
|
provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String()
|
|
switch provider {
|
|
case providerArk:
|
|
return EmbedderArk.EmbedStrings(ctx, texts)
|
|
case providerOpenai:
|
|
return EmbedderOpenAI.EmbedStrings(ctx, texts)
|
|
case providerDashscope:
|
|
return EmbedderDashscope.EmbedStrings(ctx, texts)
|
|
}
|
|
return nil, fmt.Errorf("unsupported provider: %v", provider)
|
|
}
|