From 9445a42eabf97083c2135c86c5e24909510b90c7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=BC=A0=E6=96=8C?= <259278618@qq.com>
Date: Fri, 30 Jan 2026 17:31:54 +0800
Subject: [PATCH] .gitignore

---
 rag/{enio => eino}/base_task.go     |   2 +-
 rag/eino/embedding_qwen.go          | 329 ++++++++++++++++++++++++++++
 rag/{enio => eino}/priority_enum.go |   2 +-
 rag/{enio => eino}/status_enum.go   |   2 +-
 rag/{enio => eino}/task_type.go     |   2 +-
 5 files changed, 333 insertions(+), 4 deletions(-)
 rename rag/{enio => eino}/base_task.go (98%)
 create mode 100644 rag/eino/embedding_qwen.go
 rename rag/{enio => eino}/priority_enum.go (95%)
 rename rag/{enio => eino}/status_enum.go (96%)
 rename rag/{enio => eino}/task_type.go (98%)

diff --git a/rag/enio/base_task.go b/rag/eino/base_task.go
similarity index 98%
rename from rag/enio/base_task.go
rename to rag/eino/base_task.go
index 866d5e2..06d6a86 100644
--- a/rag/enio/base_task.go
+++ b/rag/eino/base_task.go
@@ -1,4 +1,4 @@
-package enio
+package eino
 
 import (
 	"time"
diff --git a/rag/eino/embedding_qwen.go b/rag/eino/embedding_qwen.go
new file mode 100644
index 0000000..9496874
--- /dev/null
+++ b/rag/eino/embedding_qwen.go
@@ -0,0 +1,329 @@
+/*
+ * Copyright 2024 Red Future Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package eino
+
+import (
+	"context"
+	"fmt"
+	"net/http"
+	"time"
+
+	"github.com/cloudwego/eino/callbacks"
+	"github.com/cloudwego/eino/components"
+	"github.com/cloudwego/eino/components/embedding"
+	"github.com/gogf/gf/v2/frame/g"
+	"github.com/gogf/gf/v2/net/gclient"
+	"github.com/gogf/gf/v2/util/gconv"
+)
+
+// Default settings for the Qwen (DashScope) embedding API.
+const (
+	defaultBaseURL    = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding"
+	defaultTimeout    = 10 * time.Minute
+	defaultRetryTimes = 2
+)
+
+// Compile-time check that QwenEmbedder implements embedding.Embedder.
+var _ embedding.Embedder = (*QwenEmbedder)(nil)
+
+// QwenEmbeddingConfig holds the settings used to build a QwenEmbedder.
+//
+// NewQwenEmbedder fills unset optional fields with their defaults in place, so
+// the struct passed in is modified.
+type QwenEmbeddingConfig struct {
+	// Timeout specifies the maximum duration to wait for API responses.
+	// Optional. Default: 10 minutes
+	Timeout *time.Duration `json:"timeout"`
+
+	// HTTPClient specifies the client to send HTTP requests.
+	// Optional.
+	// NOTE(review): this field is not read anywhere in this file; requests are
+	// sent through the gclient built in buildQwenClient. Confirm before relying
+	// on it.
+	HTTPClient *http.Client `json:"http_client"`
+
+	// RetryTimes specifies the number of retry attempts for failed API calls.
+	// Optional. Default: 2
+	RetryTimes *int `json:"retry_times"`
+
+	// BaseURL specifies the base URL for Qwen DashScope service.
+	// Optional. Default: "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding"
+	BaseURL string `json:"base_url"`
+
+	// APIKey specifies the API Key for authentication.
+	// Required
+	APIKey string `json:"api_key"`
+
+	// Model specifies the model name for Qwen embedding.
+	// Required. Examples: "text-embedding-v2", "text-embedding-v3"
+	Model string `json:"model"`
+
+	// TextType specifies the type of text: "document" or "query".
+	// Optional. Default: "document"
+	TextType string `json:"text_type"`
+
+	// MaxConcurrentRequests specifies the maximum number of concurrent requests allowed.
+	// Optional. Default: 5
+	// NOTE(review): the default is filled in, but nothing in this file reads the
+	// value; EmbedStrings issues a single request per call.
+	MaxConcurrentRequests *int `json:"max_concurrent_requests"`
+}
+
+// QwenEmbedder generates embeddings through the Qwen DashScope HTTP API.
+type QwenEmbedder struct {
+	client *gclient.Client
+	conf   *QwenEmbeddingConfig
+}
+
+// EmbeddingRequest is the request body of the Qwen embedding API.
+type EmbeddingRequest struct {
+	Model string `json:"model"`
+	Input struct {
+		Texts []string `json:"texts"`
+	} `json:"input"`
+	Parameters struct {
+		TextType string `json:"text_type,omitempty"`
+	} `json:"parameters,omitempty"`
+}
+
+// EmbeddingResponse is the response body of the Qwen embedding API.
+type EmbeddingResponse struct {
+	Output struct {
+		Embeddings []struct {
+			TextIndex int       `json:"text_index"`
+			Embedding []float64 `json:"embedding"`
+		} `json:"embeddings"`
+	} `json:"output"`
+	Usage struct {
+		TotalTokens int `json:"total_tokens"`
+	} `json:"usage"`
+	RequestID string `json:"request_id"`
+}
+
+// APIError is the error payload decoded from a non-200 response of the Qwen API.
+type APIError struct {
+	Code      string `json:"code"`
+	Message   string `json:"message"`
+	RequestID string `json:"request_id"`
+}
+
+// Error implements the error interface.
+func (e *APIError) Error() string {
+	return fmt.Sprintf("API Error: %s - %s (RequestID: %s)", e.Code, e.Message, e.RequestID)
+}
+
+// buildQwenClient fills unset optional fields of config with their defaults
+// and returns a gclient configured with the resulting timeout and retry count.
+func buildQwenClient(config *QwenEmbeddingConfig) *gclient.Client {
+	// Pause between retry attempts of a failed request.
+	const retryInterval = time.Second
+
+	if len(config.BaseURL) == 0 {
+		config.BaseURL = defaultBaseURL
+	}
+	if config.Timeout == nil {
+		// Point at a private copy so the config never shares storage with a
+		// package-level default.
+		timeout := defaultTimeout
+		config.Timeout = &timeout
+	}
+	if config.RetryTimes == nil {
+		retryTimes := defaultRetryTimes
+		config.RetryTimes = &retryTimes
+	}
+	if len(config.TextType) == 0 {
+		config.TextType = "document"
+	}
+	if config.MaxConcurrentRequests == nil {
+		maxConcurrentRequests := 5
+		config.MaxConcurrentRequests = &maxConcurrentRequests
+	}
+
+	client := g.Client()
+	client.SetTimeout(*config.Timeout)
+	// Apply the configured retry count so RetryTimes actually takes effect.
+	if *config.RetryTimes > 0 {
+		client.SetRetry(*config.RetryTimes, retryInterval)
+	}
+
+	return client
+}
+
+// NewQwenEmbedder validates config, applies defaults to it and returns an
+// embedder bound to it. It returns an error when config is nil or when APIKey
+// or Model is empty.
+func NewQwenEmbedder(ctx context.Context, config *QwenEmbeddingConfig) (*QwenEmbedder, error) {
+	if config == nil {
+		return nil, fmt.Errorf("[Qwen] config is required")
+	}
+	if len(config.APIKey) == 0 {
+		return nil, fmt.Errorf("[Qwen] APIKey is required")
+	}
+	if len(config.Model) == 0 {
+		return nil, fmt.Errorf("[Qwen] Model is required")
+	}
+
+	return &QwenEmbedder{
+		client: buildQwenClient(config),
+		conf:   config,
+	}, nil
+}
+
+// EmbedStrings returns one embedding vector per entry of texts, in input order.
+//
+// A model passed through opts overrides the configured model for this call.
+// Callbacks are notified on start, on success and on error; a panic raised
+// while embedding is reported to the callbacks and returned as an error.
+func (e *QwenEmbedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) (embeddings [][]float64, err error) {
+	if len(texts) == 0 {
+		return nil, fmt.Errorf("[Qwen] texts cannot be empty")
+	}
+
+	options := embedding.GetCommonOptions(&embedding.Options{
+		Model: &e.conf.Model,
+	}, opts...)
+
+	conf := &embedding.Config{
+		Model: dereferenceOrZero(options.Model),
+	}
+
+	ctx = callbacks.EnsureRunInfo(ctx, e.GetType(), components.ComponentOfEmbedding)
+	ctx = callbacks.OnStart(ctx, &embedding.CallbackInput{
+		Texts:  texts,
+		Config: conf,
+	})
+	defer func() {
+		if r := recover(); r != nil {
+			// Without assigning the named results the caller would receive
+			// (nil, nil) and mistake the panic for success.
+			embeddings = nil
+			err = fmt.Errorf("[Qwen] panic: %v", r)
+			callbacks.OnError(ctx, err)
+		}
+	}()
+
+	var usage *embedding.TokenUsage
+	embeddings, usage, err = e.requestEmbeddings(ctx, conf.Model, texts)
+	if err != nil {
+		callbacks.OnError(ctx, err)
+		return nil, err
+	}
+
+	callbacks.OnEnd(ctx, &embedding.CallbackOutput{
+		Embeddings: embeddings,
+		Config:     conf,
+		TokenUsage: usage,
+	})
+
+	return embeddings, nil
+}
+
+// callEmbeddingAPI requests embeddings for texts using the configured model.
+func (e *QwenEmbedder) callEmbeddingAPI(ctx context.Context, texts []string) ([][]float64, *embedding.TokenUsage, error) {
+	return e.requestEmbeddings(ctx, e.conf.Model, texts)
+}
+
+// requestEmbeddings posts texts to the Qwen embedding endpoint with the given
+// model and returns the vectors ordered like texts, plus the reported token
+// usage. An empty model falls back to the configured one.
+func (e *QwenEmbedder) requestEmbeddings(ctx context.Context, model string, texts []string) ([][]float64, *embedding.TokenUsage, error) {
+	if len(model) == 0 {
+		model = e.conf.Model
+	}
+
+	// Build the request body.
+	var req EmbeddingRequest
+	req.Model = model
+	req.Input.Texts = texts
+	req.Parameters.TextType = e.conf.TextType
+
+	// Set the per-request headers on a clone rather than on the shared client.
+	client := e.client.Clone()
+	client.SetHeader("Authorization", "Bearer "+e.conf.APIKey)
+	client.SetHeader("Content-Type", "application/json")
+	client.SetTimeout(*e.conf.Timeout)
+
+	resp, err := client.Post(ctx, e.conf.BaseURL, req)
+	if err != nil {
+		return nil, nil, fmt.Errorf("[Qwen] HTTP request error: %w", err)
+	}
+	defer resp.Close()
+
+	body := resp.ReadAll()
+
+	// On a non-200 status, prefer the structured API error when the body carries one.
+	if resp.StatusCode != http.StatusOK {
+		var apiErr APIError
+		if err = gconv.Struct(body, &apiErr); err == nil && apiErr.Code != "" {
+			return nil, nil, &apiErr
+		}
+		return nil, nil, fmt.Errorf("[Qwen] HTTP status error: %d", resp.StatusCode)
+	}
+
+	// Decode the successful response.
+	var apiResp EmbeddingResponse
+	if err = gconv.Struct(body, &apiResp); err != nil {
+		return nil, nil, fmt.Errorf("[Qwen] parse response error: %w", err)
+	}
+
+	// Place each vector at the position of the text it belongs to.
+	embeddings := make([][]float64, len(texts))
+	for _, emb := range apiResp.Output.Embeddings {
+		if emb.TextIndex >= 0 && emb.TextIndex < len(embeddings) {
+			embeddings[emb.TextIndex] = emb.Embedding
+		}
+	}
+	// A slot left empty means the response carried no vector for that text;
+	// fail instead of handing back a partially filled result.
+	for i, vec := range embeddings {
+		if len(vec) == 0 {
+			return nil, nil, fmt.Errorf("[Qwen] missing embedding for text index %d (request_id=%s)", i, apiResp.RequestID)
+		}
+	}
+
+	usage := &embedding.TokenUsage{
+		TotalTokens: apiResp.Usage.TotalTokens,
+	}
+
+	g.Log().Debugf(ctx, "[Qwen] Embedding success: request_id=%s, total_tokens=%d", apiResp.RequestID, usage.TotalTokens)
+
+	return embeddings, usage, nil
+}
+
+// GetType returns the component type name passed to callbacks.EnsureRunInfo.
+func (e *QwenEmbedder) GetType() string {
+	return getType()
+}
+
+// IsCallbacksEnabled always returns true; EmbedStrings invokes the callbacks itself.
+func (e *QwenEmbedder) IsCallbacksEnabled() bool {
+	return true
+}
+
+// getType returns the component type name of the Qwen embedder.
+func getType() string {
+	return "Qwen"
+}
+
+// dereferenceOrZero returns *v, or the zero value of T when v is nil.
+func dereferenceOrZero[T any](v *T) T {
+	if v == nil {
+		var t T
+		return t
+	}
+	return *v
+}
diff --git a/rag/enio/priority_enum.go b/rag/eino/priority_enum.go
similarity index 95%
rename from rag/enio/priority_enum.go
rename to rag/eino/priority_enum.go
index 365903a..371706b 100644
--- a/rag/enio/priority_enum.go
+++ b/rag/eino/priority_enum.go
@@ -1,4 +1,4 @@
-package enio
+package eino
 
 // TaskPriority 任务优先级
 type TaskPriority string
diff --git a/rag/enio/status_enum.go b/rag/eino/status_enum.go
similarity index 96%
rename from rag/enio/status_enum.go
rename to rag/eino/status_enum.go
index d6d2479..6e12daf 100644
--- a/rag/enio/status_enum.go
+++ b/rag/eino/status_enum.go
@@ -1,4 +1,4 @@
-package enio
+package eino
 
 // TaskStatus 任务状态
 type TaskStatus string
diff --git a/rag/enio/task_type.go b/rag/eino/task_type.go
similarity index 98%
rename from rag/enio/task_type.go
rename to rag/eino/task_type.go
index 4dec33f..0ba5a64 100644
--- a/rag/enio/task_type.go
+++ b/rag/eino/task_type.go
@@ -1,4 +1,4 @@
-package enio
+package eino
 
 // TaskType 任务类型
 type TaskType string