/* * Copyright 2024 Red Future Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package eino import ( "context" "fmt" "net/http" "time" "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components" "github.com/cloudwego/eino/components/embedding" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/net/gclient" "github.com/gogf/gf/v2/util/gconv" ) var ( // 千问API默认配置 defaultBaseURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding" defaultTimeout = 10 * time.Minute defaultRetryTimes = 2 ) type QwenEmbeddingConfig struct { // Timeout specifies the maximum duration to wait for API responses // Optional. Default: 10 minutes Timeout *time.Duration `json:"timeout"` // HTTPClient specifies the client to send HTTP requests. // Optional. Default &http.Client{Timeout: Timeout} HTTPClient *http.Client `json:"http_client"` // RetryTimes specifies the number of retry attempts for failed API calls // Optional. Default: 2 RetryTimes *int `json:"retry_times"` // BaseURL specifies the base URL for Qwen DashScope service // Optional. Default: "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding" BaseURL string `json:"base_url"` // APIKey specifies the API Key for authentication // Required APIKey string `json:"api_key"` // Model specifies the model name for Qwen embedding // Required. Examples: "text-embedding-v2", "text-embedding-v3" Model string `json:"model"` // TextType specifies the type of text: "document" or "query" // Optional. Default: "document" TextType string `json:"text_type"` // MaxConcurrentRequests specifies the maximum number of concurrent requests allowed // Optional. Default: 5 MaxConcurrentRequests *int `json:"max_concurrent_requests"` } type QwenEmbedder struct { client *gclient.Client conf *QwenEmbeddingConfig } // EmbeddingRequest 千问embedding请求结构 type EmbeddingRequest struct { Model string `json:"model"` Input struct { Texts []string `json:"texts"` } `json:"input"` Parameters struct { TextType string `json:"text_type,omitempty"` } `json:"parameters,omitempty"` } // EmbeddingResponse 千问embedding响应结构 type EmbeddingResponse struct { Output struct { Embeddings []struct { TextIndex int `json:"text_index"` Embedding []float64 `json:"embedding"` } `json:"embeddings"` } `json:"output"` Usage struct { TotalTokens int `json:"total_tokens"` } `json:"usage"` RequestID string `json:"request_id"` } type APIError struct { Code string `json:"code"` Message string `json:"message"` RequestID string `json:"request_id"` } func (e *APIError) Error() string { return fmt.Sprintf("API Error: %s - %s (RequestID: %s)", e.Code, e.Message, e.RequestID) } func buildQwenClient(config *QwenEmbeddingConfig) *gclient.Client { if len(config.BaseURL) == 0 { config.BaseURL = defaultBaseURL } if config.Timeout == nil { config.Timeout = &defaultTimeout } if config.RetryTimes == nil { defaultRetryTimes := 2 config.RetryTimes = &defaultRetryTimes } if len(config.TextType) == 0 { config.TextType = "document" } if config.MaxConcurrentRequests == nil { defaultMaxConcurrentRequests := 5 config.MaxConcurrentRequests = &defaultMaxConcurrentRequests } client := g.Client() client.SetTimeout(*config.Timeout) return client } func NewQwenEmbedder(ctx context.Context, config *QwenEmbeddingConfig) (*QwenEmbedder, error) { if len(config.APIKey) == 0 { return nil, fmt.Errorf("[Qwen] APIKey is required") } if len(config.Model) == 0 { return nil, fmt.Errorf("[Qwen] Model is required") } client := buildQwenClient(config) return &QwenEmbedder{ client: client, conf: config, }, nil } func (e *QwenEmbedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ( [][]float64, error) { if len(texts) == 0 { return nil, fmt.Errorf("[Qwen] texts cannot be empty") } options := embedding.GetCommonOptions(&embedding.Options{ Model: &e.conf.Model, }, opts...) conf := &embedding.Config{ Model: dereferenceOrZero(options.Model), } ctx = callbacks.EnsureRunInfo(ctx, e.GetType(), components.ComponentOfEmbedding) ctx = callbacks.OnStart(ctx, &embedding.CallbackInput{ Texts: texts, Config: conf, }) defer func() { if err := recover(); err != nil { callbacks.OnError(ctx, fmt.Errorf("[Qwen] panic: %v", err)) } }() var usage *embedding.TokenUsage var embeddings [][]float64 var err error // 调用千问API获取embedding embeddings, usage, err = e.callEmbeddingAPI(ctx, texts) if err != nil { callbacks.OnError(ctx, err) return nil, err } callbacks.OnEnd(ctx, &embedding.CallbackOutput{ Embeddings: embeddings, Config: conf, TokenUsage: usage, }) return embeddings, nil } func (e *QwenEmbedder) callEmbeddingAPI(ctx context.Context, texts []string) ([][]float64, *embedding.TokenUsage, error) { // 构建请求 var req EmbeddingRequest req.Model = e.conf.Model req.Input.Texts = texts req.Parameters.TextType = e.conf.TextType // 调用API client := e.client.Clone() client.SetHeader("Authorization", "Bearer "+e.conf.APIKey) client.SetHeader("Content-Type", "application/json") client.SetTimeout(*e.conf.Timeout) resp, err := client.Post(ctx, e.conf.BaseURL, req) if err != nil { return nil, nil, fmt.Errorf("[Qwen] HTTP request error: %w", err) } defer resp.Close() // 检查状态码 if resp.StatusCode != http.StatusOK { var errResp APIError result := resp.ReadAll() if err = gconv.Struct(result, &errResp); err == nil && errResp.Code != "" { return nil, nil, &errResp } return nil, nil, fmt.Errorf("[Qwen] HTTP status error: %d", resp.StatusCode) } // 解析响应 var apiResp EmbeddingResponse result := resp.ReadAll() if err = gconv.Struct(result, &apiResp); err != nil { return nil, nil, fmt.Errorf("[Qwen] parse response error: %w", err) } // 解析响应结果 embeddings := make([][]float64, len(texts)) for _, emb := range apiResp.Output.Embeddings { if emb.TextIndex >= 0 && emb.TextIndex < len(embeddings) { embeddings[emb.TextIndex] = emb.Embedding } } usage := &embedding.TokenUsage{ TotalTokens: apiResp.Usage.TotalTokens, } g.Log().Debugf(ctx, "[Qwen] Embedding success: request_id=%s, total_tokens=%d", apiResp.RequestID, usage.TotalTokens) return embeddings, usage, nil } func (e *QwenEmbedder) GetType() string { return getType() } func (e *QwenEmbedder) IsCallbacksEnabled() bool { return true } func getType() string { return "Qwen" } func dereferenceOrZero[T any](v *T) T { if v == nil { var t T return t } return *v }