refactor: 重构文档处理流程和任务管理

This commit is contained in:
2026-04-09 09:11:43 +08:00
parent b6896f3fb4
commit 7f894745e9
34 changed files with 1216 additions and 1056 deletions

View File

@@ -1,166 +0,0 @@
package eino
import (
"context"
"errors"
"fmt"
"io"
"log"
"os"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/components/retriever"
"github.com/cloudwego/eino/schema"
"github.com/cloudwego/eino-ext/components/model/ark"
)
func main() {
ctx := context.Background()
// ==========================================
// 1. 初始化三大组件
// ==========================================
// 1.1 向量检索(从知识库查客服知识)
ragRetriever := NewPGVectorRetriever()
// 1.2 提示词模板(客服角色 + 历史 + 知识库 + 用户问题)
chatTpl := newCustomerServiceTemplate()
// 1.3 大模型ARK
chatModel, err := ark.NewChatModel(ctx, &ark.ChatModelConfig{
APIKey: os.Getenv("ARK_API_KEY"),
Model: os.Getenv("ARK_MODEL_ID"),
})
if err != nil {
log.Fatal(err)
}
// ==========================================
// 2. 模拟会话:从 DB 读取历史对话
// ==========================================
sessionHistory := []*schema.Message{
{Role: schema.User, Content: "你们发什么快递?"},
{Role: schema.Assistant, Content: "默认发中通快递"},
{Role: schema.User, Content: "可以发顺丰吗?"},
}
// 当前用户问题
userQuery := "那顺丰需要加钱吗?"
// ==========================================
// 3. RAG 检索知识库
// ==========================================
docs, err := ragRetriever.Retrieve(ctx, userQuery)
if err != nil {
log.Fatal(err)
}
// 拼接参考知识
knowledge := ""
for i, doc := range docs {
knowledge += fmt.Sprintf("[参考%d] %s\n", i+1, doc.Content)
}
// ==========================================
// 4. 模板格式化:系统提示 + 历史 + 知识 + 当前问题
// ==========================================
msgs, err := chatTpl.Format(ctx, map[string]any{
"history": sessionHistory,
"knowledge": knowledge,
"question": userQuery,
})
if err != nil {
log.Fatal(err)
}
// ==========================================
// 5. 流式调用大模型生成客服回答
// ==========================================
fmt.Println("\n=== 客服回复 ===")
stream, err := chatModel.Stream(ctx, msgs)
if err != nil {
log.Fatal(err)
}
fullReply := make([]*schema.Message, 0, 100)
for {
chunk, err := stream.Recv()
if errors.Is(err, io.EOF) {
break
}
if err != nil {
log.Fatal(err)
}
fmt.Print(chunk.Content)
fullReply = append(fullReply, chunk)
}
// ==========================================
// 6. 拼接完整回复,存入 DB 作为新历史
// ==========================================
replyMsg, _ := schema.ConcatMessages(fullReply)
sessionHistory = append(sessionHistory,
&schema.Message{Role: schema.User, Content: userQuery},
replyMsg,
)
// 接下来把 sessionHistory 存回你的 MySQL/Redis 即可
}
// ==========================================
// 本地客服提示词模板(不需要 MCP
// ==========================================
func newCustomerServiceTemplate() prompt.ChatTemplate {
// 系统提示 + 多轮对话 + 知识库 + 用户问题
return prompt.FromMessages(schema.Messages{
{
Role: schema.System,
Content: `你是电商智能客服,语气友好简洁。
请严格根据参考知识回答,不知道就说“抱歉,这个问题我需要帮你转接人工”。
参考知识:
{{.knowledge}}`,
},
// 历史对话会自动渲染在这里
{{range .history}}{{.}},{{end}},
// 当前用户问题
{Role: schema.User, Content: "{{.question}}"},
})
}
// ==========================================
// PGVector 检索器(简化可直接用)
// ==========================================
type PGVectorRetriever struct {
topK int
}
func NewPGVectorRetriever() retriever.Retriever {
return &PGVectorRetriever{topK: 3}
}
func (r *PGVectorRetriever) Retrieve(
ctx context.Context,
query string,
opts ...retriever.Option,
) ([]*schema.Document, error) {
options := retriever.GetCommonOptions(nil, opts...)
topK := r.topK
if options.TopK != nil {
topK = *options.TopK
}
// ===== 这里替换成你真实的 PG 向量检索 SQL =====
// 模拟知识库
return []*schema.Document{
{
ID: "1",
Content: "顺丰快递需要补10元运费差价",
},
{
ID: "2",
Content: "订单满99元可免费升级顺丰",
},
}, nil
}

View File

@@ -1,107 +0,0 @@
package eino
import (
"context"
"fmt"
"github.com/cloudwego/eino/schema"
"github.com/elastic/go-elasticsearch/v8"
"github.com/cloudwego/eino-ext/components/indexer/es8"
)
const (
indexName = "eino_example"
fieldContent = "content"
fieldContentVector = "content_vector"
fieldExtraLocation = "location"
docExtraLocation = "location"
)
func TestIndexer() {
ctx := context.Background()
// 1. 创建 ES 客户端
client, err := elasticsearch.NewClient(elasticsearch.Config{
Addresses: []string{"http://localhost:9200"},
})
if err != nil {
fmt.Printf("create client error: %v\n", err)
return
}
// 2. 定义 Index Spec选填如果索引不存在将自动创建
indexSpec := &es8.IndexSpec{
Settings: map[string]any{
"number_of_shards": 1,
"number_of_replicas": 0,
},
Mappings: map[string]any{
"properties": map[string]any{
fieldContentVector: map[string]any{
"type": "dense_vector",
"dims": 1024,
"index": true,
"similarity": "l2_norm",
},
},
},
}
// 4. 准备文档
// 文档通常包含 ID 和 Content
// 也可以包含额外的 Metadata 用于过滤或其他用途
docs := []*schema.Document{
{
ID: "1",
Content: "Eiffel Tower: Located in Paris, France.",
MetaData: map[string]any{
docExtraLocation: "France",
},
},
{
ID: "2",
Content: "The Great Wall: Located in China.",
MetaData: map[string]any{
docExtraLocation: "China",
},
},
}
// 5. 创建 ES 索引器组件
indexer, err := es8.NewIndexer(ctx, &es8.IndexerConfig{
Client: client,
Index: indexName,
IndexSpec: indexSpec, // 添加此项以启用自动索引创建
BatchSize: 10,
// DocumentToFields 指定如何将文档字段映射到 ES 字段
DocumentToFields: func(ctx context.Context, doc *schema.Document) (field2Value map[string]es8.FieldValue, err error) {
return map[string]es8.FieldValue{
fieldContent: {
Value: doc.Content,
EmbedKey: fieldContentVector, // 对文档内容进行向量化并保存到 "content_vector" 字段
},
fieldExtraLocation: {
// 额外的 metadata 字段
Value: doc.MetaData[docExtraLocation],
},
}, nil
},
// 提供 embedding 组件用于向量化
Embedding: EmbedderDashscope,
})
if err != nil {
fmt.Printf("create indexer error: %v\n", err)
return
}
// 6. 索引文档
ids, err := indexer.Store(ctx, docs)
if err != nil {
fmt.Printf("index error: %v\n", err)
return
}
fmt.Println("indexed ids:", ids)
}

View File

@@ -1,49 +0,0 @@
package eino
import (
"time"
"gitea.com/red-future/common/beans"
)
// BaseTask 任务基类 - MongoDB版本
type BaseTask struct {
beans.MongoBaseDO `bson:",inline"`
// 任务信息
TaskType TaskType `bson:"taskType" json:"taskType"`
Status TaskStatus `bson:"status" json:"status"`
Priority TaskPriority `bson:"priority,omitempty" json:"priority,omitempty"`
// 进度
TotalItems int64 `bson:"totalItems" json:"totalItems"`
ProcessedItems int64 `bson:"processedItems" json:"processedItems"`
Progress float64 `bson:"progress" json:"progress"`
// 结果
StartTime *time.Time `bson:"startTime" json:"startTime"`
EndTime *time.Time `bson:"endTime,omitempty" json:"endTime,omitempty"`
Duration int64 `bson:"duration,omitempty" json:"duration,omitempty"`
SuccessCount int64 `bson:"successCount" json:"successCount"`
FailCount int64 `bson:"failCount" json:"failCount"`
// 其他
Executor string `bson:"executor,omitempty" json:"executor,omitempty"`
}
// SQLBaseTask 任务基类 - SQL版本
type SQLBaseTask struct {
beans.SQLBaseDO
// 任务信息
TaskType TaskType `json:"taskType"`
Status TaskStatus `json:"status"`
Priority TaskPriority `json:"priority,omitempty"`
// 进度
TotalItems int64 `json:"totalItems"`
ProcessedItems int64 `json:"processedItems"`
Progress float64 `json:"progress"`
// 结果
StartTime *time.Time `json:"startTime"`
EndTime *time.Time `json:"endTime,omitempty"`
Duration int64 `json:"duration,omitempty"`
SuccessCount int64 `json:"successCount"`
FailCount int64 `json:"failCount"`
// 其他
Executor string `json:"executor,omitempty"`
}

View File

@@ -1,94 +0,0 @@
package eino
import (
"context"
"encoding/json"
"fmt"
"github.com/cloudwego/eino/schema"
"github.com/elastic/go-elasticsearch/v8"
"github.com/elastic/go-elasticsearch/v8/typedapi/types"
"github.com/cloudwego/eino-ext/components/retriever/es8"
"github.com/cloudwego/eino-ext/components/retriever/es8/search_mode"
)
func TestRetriever() {
ctx := context.Background()
client, _ := elasticsearch.NewClient(elasticsearch.Config{
Addresses: []string{"http://localhost:9200"},
})
// 创建 retriever 组件
retriever, _ := es8.NewRetriever(ctx, &es8.RetrieverConfig{
Client: client,
Index: indexName,
TopK: 5,
SearchMode: search_mode.SearchModeApproximate(&search_mode.ApproximateConfig{
QueryFieldName: fieldContent,
VectorFieldName: fieldContentVector,
Hybrid: false,
// RRF 仅在特定许可证下可用
// 参见: https://www.elastic.co/subscriptions
RRF: false,
RRFRankConstant: nil,
RRFWindowSize: nil,
}),
ResultParser: func(ctx context.Context, hit types.Hit) (doc *schema.Document, err error) {
doc = &schema.Document{
ID: *hit.Id_,
Content: "",
MetaData: map[string]any{},
}
var src map[string]any
if err = json.Unmarshal(hit.Source_, &src); err != nil {
return nil, err
}
for field, val := range src {
switch field {
case fieldContent:
doc.Content = val.(string)
case fieldContentVector:
var v []float64
for _, item := range val.([]interface{}) {
v = append(v, item.(float64))
}
doc.WithDenseVector(v)
case fieldExtraLocation:
doc.MetaData[docExtraLocation] = val.(string)
}
}
if hit.Score_ != nil {
doc.WithScore(float64(*hit.Score_))
}
return doc, nil
},
Embedding: EmbedderDashscope,
})
// 不带过滤器的搜索
docs, _ := retriever.Retrieve(ctx, "tourist attraction")
// 带过滤器的搜索
docs, _ = retriever.Retrieve(ctx, "tourist attraction",
es8.WithFilters([]types.Query{{
Term: map[string]types.TermQuery{
fieldExtraLocation: {
CaseInsensitive: of(true),
Value: "China",
},
},
}}),
)
fmt.Printf("retrieved docs: %+v\n", docs)
}
func of[T any](v T) *T {
return &v
}

125
common/eino/chat_model.go Normal file
View File

@@ -0,0 +1,125 @@
package eino
import (
"context"
"errors"
"fmt"
"io"
"github.com/cloudwego/eino-ext/components/model/qwen"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/schema"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/glog"
"github.com/gogf/gf/v2/util/gconv"
)
var globalChatModel *qwen.ChatModel
func init() {
ctx := context.Background()
apiKey := g.Cfg().MustGet(ctx, "eino.chatmodel.apiKey").String()
model := g.Cfg().MustGet(ctx, "eino.chatmodel.model").String()
var err error
globalChatModel, err = qwen.NewChatModel(ctx, &qwen.ChatModelConfig{
APIKey: apiKey,
Model: model,
BaseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1",
Temperature: gconv.PtrFloat32(0.7), // 客服最佳
MaxTokens: gconv.PtrInt(1024), // 最长回答
TopP: gconv.PtrFloat32(1.0),
})
if err != nil {
glog.Errorf(ctx, "初始化大模型失败: %v", err)
}
return
}
// NewChatModel 只处理逻辑,不复用创建模型
func NewChatModel(ctx context.Context, content string, docs []*schema.Document) (replyMsg *schema.Message, sources []string, err error) {
// 1. 构建参考知识
knowledge, sources := buildKnowledgeAndSources(docs)
// 2. 构建提示词
msgs, err := buildPromptMessages(ctx, knowledge, content)
if err != nil {
return
}
// 3. 🔥 直接使用全局单例,不重复创建
replyMsg, err = streamGenerateAnswer(ctx, globalChatModel, msgs)
return
}
// buildKnowledgeAndSources 拼接参考知识 + 提取文档来源
func buildKnowledgeAndSources(docs []*schema.Document) (string, []string) {
var knowledge string
var sources []string
for i, doc := range docs {
knowledge += fmt.Sprintf("[参考%d] %s\n", i+1, doc.Content)
// 提取 document_id
if docID, ok := doc.MetaData["document_id"].(int64); ok && docID > 0 {
sources = append(sources, gconv.String(docID))
}
}
return knowledge, sources
}
// buildPromptMessages 构建提示词模板
func buildPromptMessages(ctx context.Context, knowledge string, question string) (msgs []*schema.Message, err error) {
promptTpl := prompt.FromMessages(
schema.FString,
&schema.Message{
Role: schema.System,
// Content: `你是专业的客服助手,语气友好。
//如果参考知识中有相关信息,请优先依据参考知识回答。
//如果没有相关信息,就正常回答,不要说无法回答。
//
//参考知识:
//{knowledge}`,
Content: `你是专业的客服助手,语气友好。
请根据参考知识回答用户问题,无法回答则说:抱歉,我暂时无法回答这个问题。
参考知识:
{knowledge}`,
},
&schema.Message{
Role: schema.User,
Content: "{question}",
},
)
return promptTpl.Format(ctx, map[string]any{
"knowledge": knowledge,
"question": question,
})
}
// streamGenerateAnswer 流式生成
func streamGenerateAnswer(ctx context.Context, chatModel *qwen.ChatModel, msgs []*schema.Message) (reply *schema.Message, err error) {
sr, err := chatModel.Stream(ctx, msgs)
if err != nil {
return nil, fmt.Errorf("stream failed: %w", err)
}
var chunks []*schema.Message
for {
chunk, err := sr.Recv()
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return nil, fmt.Errorf("stream recv failed: %w", err)
}
chunks = append(chunks, chunk)
}
return schema.ConcatMessages(chunks)
}

View File

@@ -1,273 +0,0 @@
/*
* Copyright 2024 Red Future Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package eino
import (
"context"
"fmt"
"net/http"
"time"
"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components"
"github.com/cloudwego/eino/components/embedding"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/net/gclient"
"github.com/gogf/gf/v2/util/gconv"
)
var (
// 千问API默认配置
defaultBaseURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding"
defaultTimeout = 10 * time.Minute
defaultRetryTimes = 2
)
type QwenEmbeddingConfig struct {
// Timeout specifies the maximum duration to wait for API responses
// Optional. Default: 10 minutes
Timeout *time.Duration `json:"timeout"`
// HTTPClient specifies the client to send HTTP requests.
// Optional. Default &http.Client{Timeout: Timeout}
HTTPClient *http.Client `json:"http_client"`
// RetryTimes specifies the number of retry attempts for failed API calls
// Optional. Default: 2
RetryTimes *int `json:"retry_times"`
// BaseURL specifies the base URL for Qwen DashScope service
// Optional. Default: "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding"
BaseURL string `json:"base_url"`
// APIKey specifies the API Key for authentication
// Required
APIKey string `json:"api_key"`
// Model specifies the model name for Qwen embedding
// Required. Examples: "text-embedding-v2", "text-embedding-v3"
Model string `json:"model"`
// TextType specifies the type of text: "document" or "query"
// Optional. Default: "document"
TextType string `json:"text_type"`
// MaxConcurrentRequests specifies the maximum number of concurrent requests allowed
// Optional. Default: 5
MaxConcurrentRequests *int `json:"max_concurrent_requests"`
}
type QwenEmbedder struct {
client *gclient.Client
conf *QwenEmbeddingConfig
}
// EmbeddingRequest 千问embedding请求结构
type EmbeddingRequest struct {
Model string `json:"model"`
Input struct {
Texts []string `json:"texts"`
} `json:"input"`
Parameters struct {
TextType string `json:"text_type,omitempty"`
} `json:"parameters,omitempty"`
}
// EmbeddingResponse 千问embedding响应结构
type EmbeddingResponse struct {
Output struct {
Embeddings []struct {
TextIndex int `json:"text_index"`
Embedding []float64 `json:"embedding"`
} `json:"embeddings"`
} `json:"output"`
Usage struct {
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
RequestID string `json:"request_id"`
}
type APIError struct {
Code string `json:"code"`
Message string `json:"message"`
RequestID string `json:"request_id"`
}
func (e *APIError) Error() string {
return fmt.Sprintf("API Error: %s - %s (RequestID: %s)", e.Code, e.Message, e.RequestID)
}
func buildQwenClient(config *QwenEmbeddingConfig) *gclient.Client {
if len(config.BaseURL) == 0 {
config.BaseURL = defaultBaseURL
}
if config.Timeout == nil {
config.Timeout = &defaultTimeout
}
if config.RetryTimes == nil {
defaultRetryTimes := 2
config.RetryTimes = &defaultRetryTimes
}
if len(config.TextType) == 0 {
config.TextType = "document"
}
if config.MaxConcurrentRequests == nil {
defaultMaxConcurrentRequests := 5
config.MaxConcurrentRequests = &defaultMaxConcurrentRequests
}
client := g.Client()
client.SetTimeout(*config.Timeout)
return client
}
func NewQwenEmbedder(ctx context.Context, config *QwenEmbeddingConfig) (*QwenEmbedder, error) {
if len(config.APIKey) == 0 {
return nil, fmt.Errorf("[Qwen] APIKey is required")
}
if len(config.Model) == 0 {
return nil, fmt.Errorf("[Qwen] Model is required")
}
client := buildQwenClient(config)
return &QwenEmbedder{
client: client,
conf: config,
}, nil
}
func (e *QwenEmbedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) (
[][]float64, error) {
if len(texts) == 0 {
return nil, fmt.Errorf("[Qwen] texts cannot be empty")
}
options := embedding.GetCommonOptions(&embedding.Options{
Model: &e.conf.Model,
}, opts...)
conf := &embedding.Config{
Model: dereferenceOrZero(options.Model),
}
ctx = callbacks.EnsureRunInfo(ctx, e.GetType(), components.ComponentOfEmbedding)
ctx = callbacks.OnStart(ctx, &embedding.CallbackInput{
Texts: texts,
Config: conf,
})
defer func() {
if err := recover(); err != nil {
callbacks.OnError(ctx, fmt.Errorf("[Qwen] panic: %v", err))
}
}()
var usage *embedding.TokenUsage
var embeddings [][]float64
var err error
// 调用千问API获取embedding
embeddings, usage, err = e.callEmbeddingAPI(ctx, texts)
if err != nil {
callbacks.OnError(ctx, err)
return nil, err
}
callbacks.OnEnd(ctx, &embedding.CallbackOutput{
Embeddings: embeddings,
Config: conf,
TokenUsage: usage,
})
return embeddings, nil
}
func (e *QwenEmbedder) callEmbeddingAPI(ctx context.Context, texts []string) ([][]float64, *embedding.TokenUsage, error) {
// 构建请求
var req EmbeddingRequest
req.Model = e.conf.Model
req.Input.Texts = texts
req.Parameters.TextType = e.conf.TextType
// 调用API
client := e.client.Clone()
client.SetHeader("Authorization", "Bearer "+e.conf.APIKey)
client.SetHeader("Content-Type", "application/json")
client.SetTimeout(*e.conf.Timeout)
resp, err := client.Post(ctx, e.conf.BaseURL, req)
if err != nil {
return nil, nil, fmt.Errorf("[Qwen] HTTP request error: %w", err)
}
defer resp.Close()
// 检查状态码
if resp.StatusCode != http.StatusOK {
var errResp APIError
result := resp.ReadAll()
if err = gconv.Struct(result, &errResp); err == nil && errResp.Code != "" {
return nil, nil, &errResp
}
return nil, nil, fmt.Errorf("[Qwen] HTTP status error: %d", resp.StatusCode)
}
// 解析响应
var apiResp EmbeddingResponse
result := resp.ReadAll()
if err = gconv.Struct(result, &apiResp); err != nil {
return nil, nil, fmt.Errorf("[Qwen] parse response error: %w", err)
}
// 解析响应结果
embeddings := make([][]float64, len(texts))
for _, emb := range apiResp.Output.Embeddings {
if emb.TextIndex >= 0 && emb.TextIndex < len(embeddings) {
embeddings[emb.TextIndex] = emb.Embedding
}
}
usage := &embedding.TokenUsage{
TotalTokens: apiResp.Usage.TotalTokens,
}
g.Log().Debugf(ctx, "[Qwen] Embedding success: request_id=%s, total_tokens=%d", apiResp.RequestID, usage.TotalTokens)
return embeddings, usage, nil
}
func (e *QwenEmbedder) GetType() string {
return getType()
}
func (e *QwenEmbedder) IsCallbacksEnabled() bool {
return true
}
func getType() string {
return "Qwen"
}
func dereferenceOrZero[T any](v *T) T {
if v == nil {
var t T
return t
}
return *v
}

View File

@@ -1,11 +0,0 @@
package eino
// TaskPriority 任务优先级
type TaskPriority string
const (
TaskPriorityLow TaskPriority = "low" // 低优先级
TaskPriorityMedium TaskPriority = "medium" // 中优先级
TaskPriorityHigh TaskPriority = "high" // 高优先级
TaskPriorityUrgent TaskPriority = "urgent" // 紧急
)

View File

@@ -3,6 +3,8 @@ package eino
import (
"context"
"errors"
"rag/dao"
"sort"
"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components/embedding"
@@ -16,12 +18,14 @@ type PGVectorRetrieverConfig struct {
Embedder embedding.Embedder
DefaultTopK int
DefaultIndex string
DSLInfo map[string]any
}
type PGVectorRetriever struct {
embedder embedding.Embedder
topK int
index string
dslInfo map[string]any
}
func NewPGVectorRetriever(config *PGVectorRetrieverConfig) (*PGVectorRetriever, error) {
@@ -36,43 +40,62 @@ func NewPGVectorRetriever(config *PGVectorRetrieverConfig) (*PGVectorRetriever,
embedder: config.Embedder,
topK: config.DefaultTopK,
index: config.DefaultIndex,
dslInfo: config.DSLInfo,
}, nil
}
func (r *PGVectorRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
// 1. 处理公共 Option官方标准写法
options := &retriever.Options{
Index: &r.index,
TopK: &r.topK,
DSLInfo: r.dslInfo,
Embedding: r.embedder,
}
options = retriever.GetCommonOptions(options, opts...)
// 2. 回调(官方标准)
ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{
Query: query,
TopK: *options.TopK,
})
// 3. 执行检索
docs, err := r.doRetrieve(ctx, query, options)
// ==========================================
// 🔥 双路检索:向量 + 全文
// ==========================================
docsVector, err := r.doRetrieveVector(ctx, query, options)
if err != nil {
callbacks.OnError(ctx, err)
return nil, err
}
// 4. 完成回调
callbacks.OnEnd(ctx, &retriever.CallbackOutput{
Docs: docs,
docsFulltext, err := r.doRetrieveMeilisearch(ctx, query, options)
if err != nil {
callbacks.OnError(ctx, err)
return nil, err
}
// 合并 + 去重
docs := mergeAndDeduplicate(docsVector, docsFulltext)
// 排序distance 越小越靠前)
sort.Slice(docs, func(i, j int) bool {
d1 := gconv.Float64(docs[i].MetaData["distance"])
d2 := gconv.Float64(docs[j].MetaData["distance"])
return d1 < d2
})
// 最多保留 topK
if len(docs) > *options.TopK {
docs = docs[:*options.TopK]
}
callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: docs})
return docs, nil
}
func (r *PGVectorRetriever) doRetrieve(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
// 1. 生成向量
// ==========================================
// 1. 向量检索PG
// ==========================================
func (r *PGVectorRetriever) doRetrieveVector(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
vectors, err := opts.Embedding.EmbedStrings(ctx, []string{query})
if err != nil {
return nil, err
@@ -81,37 +104,76 @@ func (r *PGVectorRetriever) doRetrieve(ctx context.Context, query string, opts *
return nil, errors.New("empty query vector")
}
queryVec := pgvector.NewVector(vectors[0])
queryVec := pgvector.NewVector(gconv.Float32s(vectors[0]))
topK := *opts.TopK
datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"])
// 2. PG 向量相似度检索 SQL
sql := `
SELECT id, content, dataset_id, document_id,
vector <-> ? AS distance
FROM document_chunk
ORDER BY distance ASC
LIMIT ?
`
// 3. 查询
rows, err := dao.DocumentChunk.GetDB().GetAll(ctx, sql, queryVec, topK)
rows, err := dao.DocumentChunk.GetAllByVector(ctx, datasetIds, queryVec, topK)
if err != nil {
return nil, err
}
// 4. 转为 Eino Document
docs := make([]*schema.Document, 0, len(rows))
for _, row := range rows {
docs = append(docs, &schema.Document{
ID: gconv.String(row["id"]),
Content: gconv.String(row["content"]),
Metadata: map[string]any{
"dataset_id": row["dataset_id"],
"document_id": row["document_id"],
"distance": row["distance"],
MetaData: map[string]any{
"dataset_id": gconv.Int64(row["dataset_id"]),
"document_id": gconv.Int64(row["document_id"]),
"distance": gconv.Float64(row["distance"]),
"retrieve_by": "vector",
},
})
}
return docs, nil
}
// ==========================================
// 2. 全文检索Meilisearch🔥 新增
// ==========================================
func (r *PGVectorRetriever) doRetrieveMeilisearch(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
topK := *opts.TopK
datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"])
// 调用你已有的 Meilisearch DAO
rows, err := dao.DocumentChunk.SearchByKeywords(ctx, query, datasetIds, topK)
if err != nil {
return nil, err
}
docs := make([]*schema.Document, 0, len(rows))
for _, row := range rows {
docs = append(docs, &schema.Document{
ID: gconv.String(row["id"]),
Content: gconv.String(row["content"]),
MetaData: map[string]any{
"dataset_id": gconv.Int64(row["dataset_id"]),
"document_id": gconv.Int64(row["document_id"]),
"distance": 0.1, // 全文结果给高分
"retrieve_by": "fulltext",
},
})
}
return docs, nil
}
// ==========================================
// 合并去重
// ==========================================
func mergeAndDeduplicate(vecDocs, fullDocs []*schema.Document) []*schema.Document {
idMap := make(map[string]*schema.Document)
for _, d := range vecDocs {
idMap[d.ID] = d
}
for _, d := range fullDocs {
if _, exists := idMap[d.ID]; !exists {
idMap[d.ID] = d
}
}
merged := make([]*schema.Document, 0, len(idMap))
for _, d := range idMap {
merged = append(merged, d)
}
return merged
}

View File

@@ -1,12 +0,0 @@
package eino
// TaskStatus 任务状态
type TaskStatus string
const (
TaskStatusPending TaskStatus = "pending" // 待处理
TaskStatusRunning TaskStatus = "running" // 运行中
TaskStatusCompleted TaskStatus = "completed" // 已完成
TaskStatusFailed TaskStatus = "failed" // 失败
TaskStatusCancelled TaskStatus = "cancelled" // 已取消
)

View File

@@ -1,14 +0,0 @@
package eino
// TaskType 任务类型
type TaskType string
const (
TaskTypeDocumentIngestion TaskType = "document_ingestion" // 文档摄入任务
TaskTypeVectorIngestion TaskType = "vector_ingestion" // 向量摄入任务
TaskTypeIndexCreation TaskType = "index_creation" // 索引创建任务
TaskTypeQAProcessing TaskType = "qa_processing" // 问答处理任务
TaskTypeKnowledgeConstruction TaskType = "knowledge_construction" // 知识库构建任务
TaskTypeGraphBuilding TaskType = "graph_building" // 图谱构建任务
TaskTypeKnowledgeSync TaskType = "knowledge_sync" // 知识同步任务
)

View File

@@ -1,114 +0,0 @@
package gse
import (
"context"
"sort"
"github.com/go-ego/gse"
"github.com/go-ego/gse/hmm/extracker"
"github.com/go-ego/gse/hmm/segment"
"github.com/gogf/gf/v2/os/glog"
)
var GseTool *gseTool
// 初始化函数:程序启动时执行一次
func init() {
var err error
GseTool, err = newGseTool()
if err != nil {
glog.Error(context.Background(), err)
}
}
// gseTool 关键词提取工具gse v1.0.2 标准)
type gseTool struct {
seg gse.Segmenter
tfidf *extracker.TagExtracter
tr *extracker.TextRanker
}
// newGseTool 初始化工具(内置词典 + 停用词)
func newGseTool() (tool *gseTool, err error) {
// 1. 初始化分词器
var seg gse.Segmenter
// 内置词典(无外部文件)
err = seg.LoadDictEmbed()
if err != nil {
return
}
// 内置停用词v1.0.2 标准)
err = seg.LoadStopEmbed()
if err != nil {
return
}
// 2. 初始化 TF-IDF 提取器
tfidf := &extracker.TagExtracter{}
tfidf.WithGse(seg)
err = tfidf.LoadIdf()
if err != nil {
return
}
// 3. 初始化 TextRank 提取器
tr := &extracker.TextRanker{}
tr.WithGse(seg)
tool = &gseTool{
seg: seg,
tfidf: tfidf,
tr: tr,
}
return
}
// Cut 分词(关键词提取唯一正确模式:精确模式 + HMM
func (k *gseTool) Cut(text string) []string {
return k.seg.Cut(text, true)
}
// Keyword 最终输出:关键词 + 权重
type Keyword struct {
Word string `json:"word"`
Score float64 `json:"score"`
}
func (k *gseTool) Extract(text string, topN int) []Keyword {
// 1. 提取 TF-IDF
tfTags := k.extractTFIDF(text, topN)
// 2. 提取 TextRank
trTags := k.extractTextRank(text, topN)
// 3. 合并成最终关键词(业务最常用)
scoreMap := make(map[string]float64)
for _, tag := range tfTags {
scoreMap[tag.Text] = tag.Weight
}
for _, tag := range trTags {
scoreMap[tag.Text] = tag.Weight
}
// 转成切片并排序(高分在前)
res := make([]Keyword, 0, len(scoreMap))
for word, score := range scoreMap {
res = append(res, Keyword{Word: word, Score: score})
}
sort.Slice(res, func(i, j int) bool {
return res[i].Score > res[j].Score
})
return res
}
// ExtractTFIDF TF-IDF 关键词带权重90% 业务:文章标签、搜索、关键词
func (k *gseTool) extractTFIDF(text string, topN int) segment.Segments {
return k.tfidf.ExtractTags(text, topN)
}
// ExtractTextRank TextRank 关键词(带权重)长文本、摘要、语义理解
func (k *gseTool) extractTextRank(text string, topN int) segment.Segments {
return k.tr.TextRank(text, topN)
}

69
common/task/base_task.go Normal file
View File

@@ -0,0 +1,69 @@
package task
import (
"time"
"gitea.com/red-future/common/beans"
)
type baseTaskCol struct {
beans.SQLBaseCol
TaskType string
Status string
Priority string
ParentTaskID string
TotalItems string
ProcessedItems string
Progress string
StartTime string
EndTime string
Duration string
SuccessCount string
FailCount string
Executor string
DocumentID string
Remark string
}
var BaseTaskCol = baseTaskCol{
SQLBaseCol: beans.DefSQLBaseCol,
TaskType: "task_type",
Status: "status",
Priority: "task_priority",
ParentTaskID: "parent_task_id",
TotalItems: "total_items",
ProcessedItems: "processed_items",
Progress: "progress",
StartTime: "start_time",
EndTime: "end_time",
Duration: "duration",
SuccessCount: "success_count",
FailCount: "fail_count",
Executor: "executor",
DocumentID: "document_id",
Remark: "remark",
}
// SQLBaseTask 任务基类 - SQL版本
type SQLBaseTask struct {
beans.SQLBaseDO `orm:",inline"`
// 任务核心信息
TaskType TaskType `orm:"task_type" json:"taskType" dc:"任务类型"`
Status TaskStatus `orm:"status" json:"status" dc:"任务状态"`
Priority TaskPriority `orm:"task_priority" json:"priority,omitempty" dc:"任务优先级"`
ParentTaskID int64 `orm:"parent_task_id" json:"parentTaskId,omitempty" dc:"父任务ID"`
// 任务进度
TotalItems int64 `orm:"total_items" json:"totalItems" dc:"总数"`
ProcessedItems int64 `orm:"processed_items" json:"processedItems" dc:"已处理数"`
Progress float64 `orm:"progress" json:"progress" dc:"进度"` // 0~100 百分比
// 任务结果
StartTime *time.Time `orm:"start_time" json:"startTime" dc:"开始时间"`
EndTime *time.Time `orm:"end_time" json:"endTime,omitempty" dc:"结束时间"`
Duration int64 `orm:"duration" json:"duration,omitempty" dc:"耗时(毫秒)"`
SuccessCount int64 `orm:"success_count" json:"successCount" dc:"成功数"`
FailCount int64 `orm:"fail_count" json:"failCount" dc:"失败数"`
// 其他
Executor string `orm:"executor" json:"executor,omitempty" dc:"执行器标识"`
DocumentID int64 `orm:"document_id" json:"documentId,omitempty" dc:"文档ID"`
Remark string `orm:"remark" json:"remark,omitempty" dc:"备注/错误信息"`
}

30
common/task/consts.go Normal file
View File

@@ -0,0 +1,30 @@
package task
// TaskType 任务类型枚举:文档解析的三个子任务
type TaskType string
const (
TaskTypeExtractKeywords TaskType = "EXTRACT_KEYWORDS" // 提取关键词
TaskTypeGenerateVector TaskType = "GENERATE_VECTOR" // 生成向量
TaskTypeFullTextSearch TaskType = "FULL_TEXT_SEARCH" // 全文检索
TaskTypeDocParse TaskType = "DOC_PARSE" // 顶层文档解析总任务
)
// TaskStatus 任务状态枚举
type TaskStatus string
const (
TaskStatusPending TaskStatus = "PENDING" // 待执行
TaskStatusRunning TaskStatus = "RUNNING" // 执行中
TaskStatusCompleted TaskStatus = "COMPLETED" // 已完成
TaskStatusFailed TaskStatus = "FAILED" // 执行失败
)
// TaskPriority 任务优先级
type TaskPriority int
const (
TaskPriorityLow TaskPriority = 1 // 低
TaskPriorityMedium TaskPriority = 2 // 中
TaskPriorityHigh TaskPriority = 3 // 高
)