refactor: 重构文档处理流程和任务管理
This commit is contained in:
166
common/eino/a.go
166
common/eino/a.go
@@ -1,166 +0,0 @@
|
||||
package eino
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/cloudwego/eino/components/prompt"
|
||||
"github.com/cloudwego/eino/components/retriever"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
)
|
||||
|
||||
func main() {
|
||||
ctx := context.Background()
|
||||
|
||||
// ==========================================
|
||||
// 1. 初始化三大组件
|
||||
// ==========================================
|
||||
// 1.1 向量检索(从知识库查客服知识)
|
||||
ragRetriever := NewPGVectorRetriever()
|
||||
|
||||
// 1.2 提示词模板(客服角色 + 历史 + 知识库 + 用户问题)
|
||||
chatTpl := newCustomerServiceTemplate()
|
||||
|
||||
// 1.3 大模型(ARK)
|
||||
chatModel, err := ark.NewChatModel(ctx, &ark.ChatModelConfig{
|
||||
APIKey: os.Getenv("ARK_API_KEY"),
|
||||
Model: os.Getenv("ARK_MODEL_ID"),
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// ==========================================
|
||||
// 2. 模拟会话:从 DB 读取历史对话
|
||||
// ==========================================
|
||||
sessionHistory := []*schema.Message{
|
||||
{Role: schema.User, Content: "你们发什么快递?"},
|
||||
{Role: schema.Assistant, Content: "默认发中通快递"},
|
||||
{Role: schema.User, Content: "可以发顺丰吗?"},
|
||||
}
|
||||
|
||||
// 当前用户问题
|
||||
userQuery := "那顺丰需要加钱吗?"
|
||||
|
||||
// ==========================================
|
||||
// 3. RAG 检索知识库
|
||||
// ==========================================
|
||||
docs, err := ragRetriever.Retrieve(ctx, userQuery)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// 拼接参考知识
|
||||
knowledge := ""
|
||||
for i, doc := range docs {
|
||||
knowledge += fmt.Sprintf("[参考%d] %s\n", i+1, doc.Content)
|
||||
}
|
||||
|
||||
// ==========================================
|
||||
// 4. 模板格式化:系统提示 + 历史 + 知识 + 当前问题
|
||||
// ==========================================
|
||||
msgs, err := chatTpl.Format(ctx, map[string]any{
|
||||
"history": sessionHistory,
|
||||
"knowledge": knowledge,
|
||||
"question": userQuery,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// ==========================================
|
||||
// 5. 流式调用大模型生成客服回答
|
||||
// ==========================================
|
||||
fmt.Println("\n=== 客服回复 ===")
|
||||
stream, err := chatModel.Stream(ctx, msgs)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fullReply := make([]*schema.Message, 0, 100)
|
||||
for {
|
||||
chunk, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Print(chunk.Content)
|
||||
fullReply = append(fullReply, chunk)
|
||||
}
|
||||
|
||||
// ==========================================
|
||||
// 6. 拼接完整回复,存入 DB 作为新历史
|
||||
// ==========================================
|
||||
replyMsg, _ := schema.ConcatMessages(fullReply)
|
||||
sessionHistory = append(sessionHistory,
|
||||
&schema.Message{Role: schema.User, Content: userQuery},
|
||||
replyMsg,
|
||||
)
|
||||
|
||||
// 接下来把 sessionHistory 存回你的 MySQL/Redis 即可
|
||||
}
|
||||
|
||||
// ==========================================
|
||||
// 本地客服提示词模板(不需要 MCP)
|
||||
// ==========================================
|
||||
func newCustomerServiceTemplate() prompt.ChatTemplate {
|
||||
// 系统提示 + 多轮对话 + 知识库 + 用户问题
|
||||
return prompt.FromMessages(schema.Messages{
|
||||
{
|
||||
Role: schema.System,
|
||||
Content: `你是电商智能客服,语气友好简洁。
|
||||
请严格根据参考知识回答,不知道就说“抱歉,这个问题我需要帮你转接人工”。
|
||||
|
||||
参考知识:
|
||||
{{.knowledge}}`,
|
||||
},
|
||||
// 历史对话会自动渲染在这里
|
||||
{{range .history}}{{.}},{{end}},
|
||||
// 当前用户问题
|
||||
{Role: schema.User, Content: "{{.question}}"},
|
||||
})
|
||||
}
|
||||
|
||||
// ==========================================
|
||||
// PGVector 检索器(简化可直接用)
|
||||
// ==========================================
|
||||
type PGVectorRetriever struct {
|
||||
topK int
|
||||
}
|
||||
|
||||
func NewPGVectorRetriever() retriever.Retriever {
|
||||
return &PGVectorRetriever{topK: 3}
|
||||
}
|
||||
|
||||
func (r *PGVectorRetriever) Retrieve(
|
||||
ctx context.Context,
|
||||
query string,
|
||||
opts ...retriever.Option,
|
||||
) ([]*schema.Document, error) {
|
||||
|
||||
options := retriever.GetCommonOptions(nil, opts...)
|
||||
topK := r.topK
|
||||
if options.TopK != nil {
|
||||
topK = *options.TopK
|
||||
}
|
||||
|
||||
// ===== 这里替换成你真实的 PG 向量检索 SQL =====
|
||||
// 模拟知识库
|
||||
return []*schema.Document{
|
||||
{
|
||||
ID: "1",
|
||||
Content: "顺丰快递需要补10元运费差价",
|
||||
},
|
||||
{
|
||||
ID: "2",
|
||||
Content: "订单满99元可免费升级顺丰",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
107
common/eino/b.go
107
common/eino/b.go
@@ -1,107 +0,0 @@
|
||||
package eino
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/elastic/go-elasticsearch/v8"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/indexer/es8"
|
||||
)
|
||||
|
||||
const (
|
||||
indexName = "eino_example"
|
||||
fieldContent = "content"
|
||||
fieldContentVector = "content_vector"
|
||||
fieldExtraLocation = "location"
|
||||
docExtraLocation = "location"
|
||||
)
|
||||
|
||||
func TestIndexer() {
|
||||
ctx := context.Background()
|
||||
|
||||
// 1. 创建 ES 客户端
|
||||
client, err := elasticsearch.NewClient(elasticsearch.Config{
|
||||
Addresses: []string{"http://localhost:9200"},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("create client error: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 定义 Index Spec(选填:如果索引不存在,将自动创建)
|
||||
indexSpec := &es8.IndexSpec{
|
||||
Settings: map[string]any{
|
||||
"number_of_shards": 1,
|
||||
"number_of_replicas": 0,
|
||||
},
|
||||
Mappings: map[string]any{
|
||||
"properties": map[string]any{
|
||||
fieldContentVector: map[string]any{
|
||||
"type": "dense_vector",
|
||||
"dims": 1024,
|
||||
"index": true,
|
||||
"similarity": "l2_norm",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 4. 准备文档
|
||||
// 文档通常包含 ID 和 Content
|
||||
// 也可以包含额外的 Metadata 用于过滤或其他用途
|
||||
docs := []*schema.Document{
|
||||
{
|
||||
ID: "1",
|
||||
Content: "Eiffel Tower: Located in Paris, France.",
|
||||
MetaData: map[string]any{
|
||||
docExtraLocation: "France",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "2",
|
||||
Content: "The Great Wall: Located in China.",
|
||||
MetaData: map[string]any{
|
||||
docExtraLocation: "China",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 5. 创建 ES 索引器组件
|
||||
indexer, err := es8.NewIndexer(ctx, &es8.IndexerConfig{
|
||||
Client: client,
|
||||
Index: indexName,
|
||||
IndexSpec: indexSpec, // 添加此项以启用自动索引创建
|
||||
BatchSize: 10,
|
||||
// DocumentToFields 指定如何将文档字段映射到 ES 字段
|
||||
DocumentToFields: func(ctx context.Context, doc *schema.Document) (field2Value map[string]es8.FieldValue, err error) {
|
||||
return map[string]es8.FieldValue{
|
||||
fieldContent: {
|
||||
Value: doc.Content,
|
||||
EmbedKey: fieldContentVector, // 对文档内容进行向量化并保存到 "content_vector" 字段
|
||||
},
|
||||
fieldExtraLocation: {
|
||||
// 额外的 metadata 字段
|
||||
Value: doc.MetaData[docExtraLocation],
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
// 提供 embedding 组件用于向量化
|
||||
Embedding: EmbedderDashscope,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("create indexer error: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 6. 索引文档
|
||||
ids, err := indexer.Store(ctx, docs)
|
||||
if err != nil {
|
||||
fmt.Printf("index error: %v\n", err)
|
||||
return
|
||||
}
|
||||
fmt.Println("indexed ids:", ids)
|
||||
}
|
||||
@@ -1,49 +0,0 @@
|
||||
package eino
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
)
|
||||
|
||||
// BaseTask 任务基类 - MongoDB版本
|
||||
type BaseTask struct {
|
||||
beans.MongoBaseDO `bson:",inline"`
|
||||
// 任务信息
|
||||
TaskType TaskType `bson:"taskType" json:"taskType"`
|
||||
Status TaskStatus `bson:"status" json:"status"`
|
||||
Priority TaskPriority `bson:"priority,omitempty" json:"priority,omitempty"`
|
||||
// 进度
|
||||
TotalItems int64 `bson:"totalItems" json:"totalItems"`
|
||||
ProcessedItems int64 `bson:"processedItems" json:"processedItems"`
|
||||
Progress float64 `bson:"progress" json:"progress"`
|
||||
// 结果
|
||||
StartTime *time.Time `bson:"startTime" json:"startTime"`
|
||||
EndTime *time.Time `bson:"endTime,omitempty" json:"endTime,omitempty"`
|
||||
Duration int64 `bson:"duration,omitempty" json:"duration,omitempty"`
|
||||
SuccessCount int64 `bson:"successCount" json:"successCount"`
|
||||
FailCount int64 `bson:"failCount" json:"failCount"`
|
||||
// 其他
|
||||
Executor string `bson:"executor,omitempty" json:"executor,omitempty"`
|
||||
}
|
||||
|
||||
// SQLBaseTask 任务基类 - SQL版本
|
||||
type SQLBaseTask struct {
|
||||
beans.SQLBaseDO
|
||||
// 任务信息
|
||||
TaskType TaskType `json:"taskType"`
|
||||
Status TaskStatus `json:"status"`
|
||||
Priority TaskPriority `json:"priority,omitempty"`
|
||||
// 进度
|
||||
TotalItems int64 `json:"totalItems"`
|
||||
ProcessedItems int64 `json:"processedItems"`
|
||||
Progress float64 `json:"progress"`
|
||||
// 结果
|
||||
StartTime *time.Time `json:"startTime"`
|
||||
EndTime *time.Time `json:"endTime,omitempty"`
|
||||
Duration int64 `json:"duration,omitempty"`
|
||||
SuccessCount int64 `json:"successCount"`
|
||||
FailCount int64 `json:"failCount"`
|
||||
// 其他
|
||||
Executor string `json:"executor,omitempty"`
|
||||
}
|
||||
@@ -1,94 +0,0 @@
|
||||
package eino
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/elastic/go-elasticsearch/v8"
|
||||
"github.com/elastic/go-elasticsearch/v8/typedapi/types"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/retriever/es8"
|
||||
"github.com/cloudwego/eino-ext/components/retriever/es8/search_mode"
|
||||
)
|
||||
|
||||
func TestRetriever() {
|
||||
ctx := context.Background()
|
||||
|
||||
client, _ := elasticsearch.NewClient(elasticsearch.Config{
|
||||
Addresses: []string{"http://localhost:9200"},
|
||||
})
|
||||
|
||||
// 创建 retriever 组件
|
||||
retriever, _ := es8.NewRetriever(ctx, &es8.RetrieverConfig{
|
||||
Client: client,
|
||||
Index: indexName,
|
||||
TopK: 5,
|
||||
SearchMode: search_mode.SearchModeApproximate(&search_mode.ApproximateConfig{
|
||||
QueryFieldName: fieldContent,
|
||||
VectorFieldName: fieldContentVector,
|
||||
Hybrid: false,
|
||||
// RRF 仅在特定许可证下可用
|
||||
// 参见: https://www.elastic.co/subscriptions
|
||||
RRF: false,
|
||||
RRFRankConstant: nil,
|
||||
RRFWindowSize: nil,
|
||||
}),
|
||||
ResultParser: func(ctx context.Context, hit types.Hit) (doc *schema.Document, err error) {
|
||||
doc = &schema.Document{
|
||||
ID: *hit.Id_,
|
||||
Content: "",
|
||||
MetaData: map[string]any{},
|
||||
}
|
||||
|
||||
var src map[string]any
|
||||
if err = json.Unmarshal(hit.Source_, &src); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for field, val := range src {
|
||||
switch field {
|
||||
case fieldContent:
|
||||
doc.Content = val.(string)
|
||||
case fieldContentVector:
|
||||
var v []float64
|
||||
for _, item := range val.([]interface{}) {
|
||||
v = append(v, item.(float64))
|
||||
}
|
||||
doc.WithDenseVector(v)
|
||||
case fieldExtraLocation:
|
||||
doc.MetaData[docExtraLocation] = val.(string)
|
||||
}
|
||||
}
|
||||
|
||||
if hit.Score_ != nil {
|
||||
doc.WithScore(float64(*hit.Score_))
|
||||
}
|
||||
|
||||
return doc, nil
|
||||
},
|
||||
Embedding: EmbedderDashscope,
|
||||
})
|
||||
|
||||
// 不带过滤器的搜索
|
||||
docs, _ := retriever.Retrieve(ctx, "tourist attraction")
|
||||
|
||||
// 带过滤器的搜索
|
||||
docs, _ = retriever.Retrieve(ctx, "tourist attraction",
|
||||
es8.WithFilters([]types.Query{{
|
||||
Term: map[string]types.TermQuery{
|
||||
fieldExtraLocation: {
|
||||
CaseInsensitive: of(true),
|
||||
Value: "China",
|
||||
},
|
||||
},
|
||||
}}),
|
||||
)
|
||||
|
||||
fmt.Printf("retrieved docs: %+v\n", docs)
|
||||
}
|
||||
|
||||
func of[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
125
common/eino/chat_model.go
Normal file
125
common/eino/chat_model.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package eino
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/qwen"
|
||||
"github.com/cloudwego/eino/components/prompt"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/os/glog"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
var globalChatModel *qwen.ChatModel
|
||||
|
||||
func init() {
|
||||
ctx := context.Background()
|
||||
|
||||
apiKey := g.Cfg().MustGet(ctx, "eino.chatmodel.apiKey").String()
|
||||
model := g.Cfg().MustGet(ctx, "eino.chatmodel.model").String()
|
||||
|
||||
var err error
|
||||
globalChatModel, err = qwen.NewChatModel(ctx, &qwen.ChatModelConfig{
|
||||
APIKey: apiKey,
|
||||
Model: model,
|
||||
BaseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
Temperature: gconv.PtrFloat32(0.7), // 客服最佳
|
||||
MaxTokens: gconv.PtrInt(1024), // 最长回答
|
||||
TopP: gconv.PtrFloat32(1.0),
|
||||
})
|
||||
if err != nil {
|
||||
glog.Errorf(ctx, "初始化大模型失败: %v", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// NewChatModel 只处理逻辑,不复用创建模型
|
||||
func NewChatModel(ctx context.Context, content string, docs []*schema.Document) (replyMsg *schema.Message, sources []string, err error) {
|
||||
// 1. 构建参考知识
|
||||
knowledge, sources := buildKnowledgeAndSources(docs)
|
||||
|
||||
// 2. 构建提示词
|
||||
msgs, err := buildPromptMessages(ctx, knowledge, content)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 🔥 直接使用全局单例,不重复创建
|
||||
replyMsg, err = streamGenerateAnswer(ctx, globalChatModel, msgs)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// buildKnowledgeAndSources 拼接参考知识 + 提取文档来源
|
||||
func buildKnowledgeAndSources(docs []*schema.Document) (string, []string) {
|
||||
var knowledge string
|
||||
var sources []string
|
||||
|
||||
for i, doc := range docs {
|
||||
knowledge += fmt.Sprintf("[参考%d] %s\n", i+1, doc.Content)
|
||||
|
||||
// 提取 document_id
|
||||
if docID, ok := doc.MetaData["document_id"].(int64); ok && docID > 0 {
|
||||
sources = append(sources, gconv.String(docID))
|
||||
}
|
||||
}
|
||||
return knowledge, sources
|
||||
}
|
||||
|
||||
// buildPromptMessages 构建提示词模板
|
||||
func buildPromptMessages(ctx context.Context, knowledge string, question string) (msgs []*schema.Message, err error) {
|
||||
promptTpl := prompt.FromMessages(
|
||||
schema.FString,
|
||||
&schema.Message{
|
||||
Role: schema.System,
|
||||
// Content: `你是专业的客服助手,语气友好。
|
||||
//如果参考知识中有相关信息,请优先依据参考知识回答。
|
||||
//如果没有相关信息,就正常回答,不要说无法回答。
|
||||
//
|
||||
//参考知识:
|
||||
//{knowledge}`,
|
||||
Content: `你是专业的客服助手,语气友好。
|
||||
请根据参考知识回答用户问题,无法回答则说:抱歉,我暂时无法回答这个问题。
|
||||
|
||||
参考知识:
|
||||
{knowledge}`,
|
||||
},
|
||||
&schema.Message{
|
||||
Role: schema.User,
|
||||
Content: "{question}",
|
||||
},
|
||||
)
|
||||
|
||||
return promptTpl.Format(ctx, map[string]any{
|
||||
"knowledge": knowledge,
|
||||
"question": question,
|
||||
})
|
||||
}
|
||||
|
||||
// streamGenerateAnswer 流式生成
|
||||
func streamGenerateAnswer(ctx context.Context, chatModel *qwen.ChatModel, msgs []*schema.Message) (reply *schema.Message, err error) {
|
||||
|
||||
sr, err := chatModel.Stream(ctx, msgs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stream failed: %w", err)
|
||||
}
|
||||
|
||||
var chunks []*schema.Message
|
||||
for {
|
||||
chunk, err := sr.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stream recv failed: %w", err)
|
||||
}
|
||||
chunks = append(chunks, chunk)
|
||||
}
|
||||
|
||||
return schema.ConcatMessages(chunks)
|
||||
}
|
||||
@@ -1,273 +0,0 @@
|
||||
/*
|
||||
* Copyright 2024 Red Future Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package eino
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
"github.com/cloudwego/eino/components"
|
||||
"github.com/cloudwego/eino/components/embedding"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/net/gclient"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
var (
|
||||
// 千问API默认配置
|
||||
defaultBaseURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding"
|
||||
defaultTimeout = 10 * time.Minute
|
||||
defaultRetryTimes = 2
|
||||
)
|
||||
|
||||
type QwenEmbeddingConfig struct {
|
||||
// Timeout specifies the maximum duration to wait for API responses
|
||||
// Optional. Default: 10 minutes
|
||||
Timeout *time.Duration `json:"timeout"`
|
||||
|
||||
// HTTPClient specifies the client to send HTTP requests.
|
||||
// Optional. Default &http.Client{Timeout: Timeout}
|
||||
HTTPClient *http.Client `json:"http_client"`
|
||||
|
||||
// RetryTimes specifies the number of retry attempts for failed API calls
|
||||
// Optional. Default: 2
|
||||
RetryTimes *int `json:"retry_times"`
|
||||
|
||||
// BaseURL specifies the base URL for Qwen DashScope service
|
||||
// Optional. Default: "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding"
|
||||
BaseURL string `json:"base_url"`
|
||||
|
||||
// APIKey specifies the API Key for authentication
|
||||
// Required
|
||||
APIKey string `json:"api_key"`
|
||||
|
||||
// Model specifies the model name for Qwen embedding
|
||||
// Required. Examples: "text-embedding-v2", "text-embedding-v3"
|
||||
Model string `json:"model"`
|
||||
|
||||
// TextType specifies the type of text: "document" or "query"
|
||||
// Optional. Default: "document"
|
||||
TextType string `json:"text_type"`
|
||||
|
||||
// MaxConcurrentRequests specifies the maximum number of concurrent requests allowed
|
||||
// Optional. Default: 5
|
||||
MaxConcurrentRequests *int `json:"max_concurrent_requests"`
|
||||
}
|
||||
|
||||
type QwenEmbedder struct {
|
||||
client *gclient.Client
|
||||
conf *QwenEmbeddingConfig
|
||||
}
|
||||
|
||||
// EmbeddingRequest 千问embedding请求结构
|
||||
type EmbeddingRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input struct {
|
||||
Texts []string `json:"texts"`
|
||||
} `json:"input"`
|
||||
Parameters struct {
|
||||
TextType string `json:"text_type,omitempty"`
|
||||
} `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
// EmbeddingResponse 千问embedding响应结构
|
||||
type EmbeddingResponse struct {
|
||||
Output struct {
|
||||
Embeddings []struct {
|
||||
TextIndex int `json:"text_index"`
|
||||
Embedding []float64 `json:"embedding"`
|
||||
} `json:"embeddings"`
|
||||
} `json:"output"`
|
||||
Usage struct {
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
} `json:"usage"`
|
||||
RequestID string `json:"request_id"`
|
||||
}
|
||||
|
||||
type APIError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
RequestID string `json:"request_id"`
|
||||
}
|
||||
|
||||
func (e *APIError) Error() string {
|
||||
return fmt.Sprintf("API Error: %s - %s (RequestID: %s)", e.Code, e.Message, e.RequestID)
|
||||
}
|
||||
|
||||
func buildQwenClient(config *QwenEmbeddingConfig) *gclient.Client {
|
||||
if len(config.BaseURL) == 0 {
|
||||
config.BaseURL = defaultBaseURL
|
||||
}
|
||||
if config.Timeout == nil {
|
||||
config.Timeout = &defaultTimeout
|
||||
}
|
||||
if config.RetryTimes == nil {
|
||||
defaultRetryTimes := 2
|
||||
config.RetryTimes = &defaultRetryTimes
|
||||
}
|
||||
if len(config.TextType) == 0 {
|
||||
config.TextType = "document"
|
||||
}
|
||||
if config.MaxConcurrentRequests == nil {
|
||||
defaultMaxConcurrentRequests := 5
|
||||
config.MaxConcurrentRequests = &defaultMaxConcurrentRequests
|
||||
}
|
||||
|
||||
client := g.Client()
|
||||
client.SetTimeout(*config.Timeout)
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
func NewQwenEmbedder(ctx context.Context, config *QwenEmbeddingConfig) (*QwenEmbedder, error) {
|
||||
if len(config.APIKey) == 0 {
|
||||
return nil, fmt.Errorf("[Qwen] APIKey is required")
|
||||
}
|
||||
if len(config.Model) == 0 {
|
||||
return nil, fmt.Errorf("[Qwen] Model is required")
|
||||
}
|
||||
|
||||
client := buildQwenClient(config)
|
||||
|
||||
return &QwenEmbedder{
|
||||
client: client,
|
||||
conf: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *QwenEmbedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) (
|
||||
[][]float64, error) {
|
||||
|
||||
if len(texts) == 0 {
|
||||
return nil, fmt.Errorf("[Qwen] texts cannot be empty")
|
||||
}
|
||||
|
||||
options := embedding.GetCommonOptions(&embedding.Options{
|
||||
Model: &e.conf.Model,
|
||||
}, opts...)
|
||||
|
||||
conf := &embedding.Config{
|
||||
Model: dereferenceOrZero(options.Model),
|
||||
}
|
||||
|
||||
ctx = callbacks.EnsureRunInfo(ctx, e.GetType(), components.ComponentOfEmbedding)
|
||||
ctx = callbacks.OnStart(ctx, &embedding.CallbackInput{
|
||||
Texts: texts,
|
||||
Config: conf,
|
||||
})
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
callbacks.OnError(ctx, fmt.Errorf("[Qwen] panic: %v", err))
|
||||
}
|
||||
}()
|
||||
|
||||
var usage *embedding.TokenUsage
|
||||
var embeddings [][]float64
|
||||
var err error
|
||||
|
||||
// 调用千问API获取embedding
|
||||
embeddings, usage, err = e.callEmbeddingAPI(ctx, texts)
|
||||
if err != nil {
|
||||
callbacks.OnError(ctx, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
callbacks.OnEnd(ctx, &embedding.CallbackOutput{
|
||||
Embeddings: embeddings,
|
||||
Config: conf,
|
||||
TokenUsage: usage,
|
||||
})
|
||||
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
func (e *QwenEmbedder) callEmbeddingAPI(ctx context.Context, texts []string) ([][]float64, *embedding.TokenUsage, error) {
|
||||
// 构建请求
|
||||
var req EmbeddingRequest
|
||||
req.Model = e.conf.Model
|
||||
req.Input.Texts = texts
|
||||
req.Parameters.TextType = e.conf.TextType
|
||||
|
||||
// 调用API
|
||||
client := e.client.Clone()
|
||||
client.SetHeader("Authorization", "Bearer "+e.conf.APIKey)
|
||||
client.SetHeader("Content-Type", "application/json")
|
||||
client.SetTimeout(*e.conf.Timeout)
|
||||
|
||||
resp, err := client.Post(ctx, e.conf.BaseURL, req)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("[Qwen] HTTP request error: %w", err)
|
||||
}
|
||||
|
||||
defer resp.Close()
|
||||
|
||||
// 检查状态码
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var errResp APIError
|
||||
result := resp.ReadAll()
|
||||
if err = gconv.Struct(result, &errResp); err == nil && errResp.Code != "" {
|
||||
return nil, nil, &errResp
|
||||
}
|
||||
return nil, nil, fmt.Errorf("[Qwen] HTTP status error: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var apiResp EmbeddingResponse
|
||||
result := resp.ReadAll()
|
||||
if err = gconv.Struct(result, &apiResp); err != nil {
|
||||
return nil, nil, fmt.Errorf("[Qwen] parse response error: %w", err)
|
||||
}
|
||||
|
||||
// 解析响应结果
|
||||
embeddings := make([][]float64, len(texts))
|
||||
for _, emb := range apiResp.Output.Embeddings {
|
||||
if emb.TextIndex >= 0 && emb.TextIndex < len(embeddings) {
|
||||
embeddings[emb.TextIndex] = emb.Embedding
|
||||
}
|
||||
}
|
||||
|
||||
usage := &embedding.TokenUsage{
|
||||
TotalTokens: apiResp.Usage.TotalTokens,
|
||||
}
|
||||
|
||||
g.Log().Debugf(ctx, "[Qwen] Embedding success: request_id=%s, total_tokens=%d", apiResp.RequestID, usage.TotalTokens)
|
||||
|
||||
return embeddings, usage, nil
|
||||
}
|
||||
|
||||
func (e *QwenEmbedder) GetType() string {
|
||||
return getType()
|
||||
}
|
||||
|
||||
func (e *QwenEmbedder) IsCallbacksEnabled() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func getType() string {
|
||||
return "Qwen"
|
||||
}
|
||||
|
||||
func dereferenceOrZero[T any](v *T) T {
|
||||
if v == nil {
|
||||
var t T
|
||||
return t
|
||||
}
|
||||
return *v
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
package eino
|
||||
|
||||
// TaskPriority 任务优先级
|
||||
type TaskPriority string
|
||||
|
||||
const (
|
||||
TaskPriorityLow TaskPriority = "low" // 低优先级
|
||||
TaskPriorityMedium TaskPriority = "medium" // 中优先级
|
||||
TaskPriorityHigh TaskPriority = "high" // 高优先级
|
||||
TaskPriorityUrgent TaskPriority = "urgent" // 紧急
|
||||
)
|
||||
@@ -3,6 +3,8 @@ package eino
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"rag/dao"
|
||||
"sort"
|
||||
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
"github.com/cloudwego/eino/components/embedding"
|
||||
@@ -16,12 +18,14 @@ type PGVectorRetrieverConfig struct {
|
||||
Embedder embedding.Embedder
|
||||
DefaultTopK int
|
||||
DefaultIndex string
|
||||
DSLInfo map[string]any
|
||||
}
|
||||
|
||||
type PGVectorRetriever struct {
|
||||
embedder embedding.Embedder
|
||||
topK int
|
||||
index string
|
||||
dslInfo map[string]any
|
||||
}
|
||||
|
||||
func NewPGVectorRetriever(config *PGVectorRetrieverConfig) (*PGVectorRetriever, error) {
|
||||
@@ -36,43 +40,62 @@ func NewPGVectorRetriever(config *PGVectorRetrieverConfig) (*PGVectorRetriever,
|
||||
embedder: config.Embedder,
|
||||
topK: config.DefaultTopK,
|
||||
index: config.DefaultIndex,
|
||||
dslInfo: config.DSLInfo,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *PGVectorRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
|
||||
|
||||
// 1. 处理公共 Option(官方标准写法)
|
||||
options := &retriever.Options{
|
||||
Index: &r.index,
|
||||
TopK: &r.topK,
|
||||
DSLInfo: r.dslInfo,
|
||||
Embedding: r.embedder,
|
||||
}
|
||||
options = retriever.GetCommonOptions(options, opts...)
|
||||
|
||||
// 2. 回调(官方标准)
|
||||
ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{
|
||||
Query: query,
|
||||
TopK: *options.TopK,
|
||||
})
|
||||
|
||||
// 3. 执行检索
|
||||
docs, err := r.doRetrieve(ctx, query, options)
|
||||
// ==========================================
|
||||
// 🔥 双路检索:向量 + 全文
|
||||
// ==========================================
|
||||
docsVector, err := r.doRetrieveVector(ctx, query, options)
|
||||
if err != nil {
|
||||
callbacks.OnError(ctx, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 4. 完成回调
|
||||
callbacks.OnEnd(ctx, &retriever.CallbackOutput{
|
||||
Docs: docs,
|
||||
docsFulltext, err := r.doRetrieveMeilisearch(ctx, query, options)
|
||||
if err != nil {
|
||||
callbacks.OnError(ctx, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 合并 + 去重
|
||||
docs := mergeAndDeduplicate(docsVector, docsFulltext)
|
||||
|
||||
// 排序(distance 越小越靠前)
|
||||
sort.Slice(docs, func(i, j int) bool {
|
||||
d1 := gconv.Float64(docs[i].MetaData["distance"])
|
||||
d2 := gconv.Float64(docs[j].MetaData["distance"])
|
||||
return d1 < d2
|
||||
})
|
||||
|
||||
// 最多保留 topK
|
||||
if len(docs) > *options.TopK {
|
||||
docs = docs[:*options.TopK]
|
||||
}
|
||||
|
||||
callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: docs})
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
func (r *PGVectorRetriever) doRetrieve(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
|
||||
|
||||
// 1. 生成向量
|
||||
// ==========================================
|
||||
// 1. 向量检索(PG)
|
||||
// ==========================================
|
||||
func (r *PGVectorRetriever) doRetrieveVector(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
|
||||
vectors, err := opts.Embedding.EmbedStrings(ctx, []string{query})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -81,37 +104,76 @@ func (r *PGVectorRetriever) doRetrieve(ctx context.Context, query string, opts *
|
||||
return nil, errors.New("empty query vector")
|
||||
}
|
||||
|
||||
queryVec := pgvector.NewVector(vectors[0])
|
||||
queryVec := pgvector.NewVector(gconv.Float32s(vectors[0]))
|
||||
topK := *opts.TopK
|
||||
datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"])
|
||||
|
||||
// 2. PG 向量相似度检索 SQL
|
||||
sql := `
|
||||
SELECT id, content, dataset_id, document_id,
|
||||
vector <-> ? AS distance
|
||||
FROM document_chunk
|
||||
ORDER BY distance ASC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
// 3. 查询
|
||||
rows, err := dao.DocumentChunk.GetDB().GetAll(ctx, sql, queryVec, topK)
|
||||
rows, err := dao.DocumentChunk.GetAllByVector(ctx, datasetIds, queryVec, topK)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 4. 转为 Eino Document
|
||||
docs := make([]*schema.Document, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
docs = append(docs, &schema.Document{
|
||||
ID: gconv.String(row["id"]),
|
||||
Content: gconv.String(row["content"]),
|
||||
Metadata: map[string]any{
|
||||
"dataset_id": row["dataset_id"],
|
||||
"document_id": row["document_id"],
|
||||
"distance": row["distance"],
|
||||
MetaData: map[string]any{
|
||||
"dataset_id": gconv.Int64(row["dataset_id"]),
|
||||
"document_id": gconv.Int64(row["document_id"]),
|
||||
"distance": gconv.Float64(row["distance"]),
|
||||
"retrieve_by": "vector",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
// ==========================================
|
||||
// 2. 全文检索(Meilisearch)🔥 新增
|
||||
// ==========================================
|
||||
func (r *PGVectorRetriever) doRetrieveMeilisearch(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
|
||||
topK := *opts.TopK
|
||||
datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"])
|
||||
|
||||
// 调用你已有的 Meilisearch DAO
|
||||
rows, err := dao.DocumentChunk.SearchByKeywords(ctx, query, datasetIds, topK)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
docs := make([]*schema.Document, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
docs = append(docs, &schema.Document{
|
||||
ID: gconv.String(row["id"]),
|
||||
Content: gconv.String(row["content"]),
|
||||
MetaData: map[string]any{
|
||||
"dataset_id": gconv.Int64(row["dataset_id"]),
|
||||
"document_id": gconv.Int64(row["document_id"]),
|
||||
"distance": 0.1, // 全文结果给高分
|
||||
"retrieve_by": "fulltext",
|
||||
},
|
||||
})
|
||||
}
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
// ==========================================
|
||||
// 合并去重
|
||||
// ==========================================
|
||||
func mergeAndDeduplicate(vecDocs, fullDocs []*schema.Document) []*schema.Document {
|
||||
idMap := make(map[string]*schema.Document)
|
||||
for _, d := range vecDocs {
|
||||
idMap[d.ID] = d
|
||||
}
|
||||
for _, d := range fullDocs {
|
||||
if _, exists := idMap[d.ID]; !exists {
|
||||
idMap[d.ID] = d
|
||||
}
|
||||
}
|
||||
merged := make([]*schema.Document, 0, len(idMap))
|
||||
for _, d := range idMap {
|
||||
merged = append(merged, d)
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
package eino
|
||||
|
||||
// TaskStatus 任务状态
|
||||
type TaskStatus string
|
||||
|
||||
const (
|
||||
TaskStatusPending TaskStatus = "pending" // 待处理
|
||||
TaskStatusRunning TaskStatus = "running" // 运行中
|
||||
TaskStatusCompleted TaskStatus = "completed" // 已完成
|
||||
TaskStatusFailed TaskStatus = "failed" // 失败
|
||||
TaskStatusCancelled TaskStatus = "cancelled" // 已取消
|
||||
)
|
||||
@@ -1,14 +0,0 @@
|
||||
package eino
|
||||
|
||||
// TaskType 任务类型
|
||||
type TaskType string
|
||||
|
||||
const (
|
||||
TaskTypeDocumentIngestion TaskType = "document_ingestion" // 文档摄入任务
|
||||
TaskTypeVectorIngestion TaskType = "vector_ingestion" // 向量摄入任务
|
||||
TaskTypeIndexCreation TaskType = "index_creation" // 索引创建任务
|
||||
TaskTypeQAProcessing TaskType = "qa_processing" // 问答处理任务
|
||||
TaskTypeKnowledgeConstruction TaskType = "knowledge_construction" // 知识库构建任务
|
||||
TaskTypeGraphBuilding TaskType = "graph_building" // 图谱构建任务
|
||||
TaskTypeKnowledgeSync TaskType = "knowledge_sync" // 知识同步任务
|
||||
)
|
||||
@@ -1,114 +0,0 @@
|
||||
package gse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
|
||||
"github.com/go-ego/gse"
|
||||
"github.com/go-ego/gse/hmm/extracker"
|
||||
"github.com/go-ego/gse/hmm/segment"
|
||||
"github.com/gogf/gf/v2/os/glog"
|
||||
)
|
||||
|
||||
var GseTool *gseTool
|
||||
|
||||
// 初始化函数:程序启动时执行一次
|
||||
func init() {
|
||||
var err error
|
||||
GseTool, err = newGseTool()
|
||||
if err != nil {
|
||||
glog.Error(context.Background(), err)
|
||||
}
|
||||
}
|
||||
|
||||
// gseTool 关键词提取工具(gse v1.0.2 标准)
|
||||
type gseTool struct {
|
||||
seg gse.Segmenter
|
||||
tfidf *extracker.TagExtracter
|
||||
tr *extracker.TextRanker
|
||||
}
|
||||
|
||||
// newGseTool 初始化工具(内置词典 + 停用词)
|
||||
func newGseTool() (tool *gseTool, err error) {
|
||||
// 1. 初始化分词器
|
||||
var seg gse.Segmenter
|
||||
// 内置词典(无外部文件)
|
||||
err = seg.LoadDictEmbed()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// 内置停用词(v1.0.2 标准)
|
||||
err = seg.LoadStopEmbed()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 初始化 TF-IDF 提取器
|
||||
tfidf := &extracker.TagExtracter{}
|
||||
tfidf.WithGse(seg)
|
||||
err = tfidf.LoadIdf()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 初始化 TextRank 提取器
|
||||
tr := &extracker.TextRanker{}
|
||||
tr.WithGse(seg)
|
||||
|
||||
tool = &gseTool{
|
||||
seg: seg,
|
||||
tfidf: tfidf,
|
||||
tr: tr,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Cut 分词(关键词提取唯一正确模式:精确模式 + HMM)
|
||||
func (k *gseTool) Cut(text string) []string {
|
||||
return k.seg.Cut(text, true)
|
||||
}
|
||||
|
||||
// Keyword 最终输出:关键词 + 权重
|
||||
type Keyword struct {
|
||||
Word string `json:"word"`
|
||||
Score float64 `json:"score"`
|
||||
}
|
||||
|
||||
func (k *gseTool) Extract(text string, topN int) []Keyword {
|
||||
// 1. 提取 TF-IDF
|
||||
tfTags := k.extractTFIDF(text, topN)
|
||||
|
||||
// 2. 提取 TextRank
|
||||
trTags := k.extractTextRank(text, topN)
|
||||
|
||||
// 3. 合并成最终关键词(业务最常用)
|
||||
scoreMap := make(map[string]float64)
|
||||
for _, tag := range tfTags {
|
||||
scoreMap[tag.Text] = tag.Weight
|
||||
}
|
||||
for _, tag := range trTags {
|
||||
scoreMap[tag.Text] = tag.Weight
|
||||
}
|
||||
|
||||
// 转成切片并排序(高分在前)
|
||||
res := make([]Keyword, 0, len(scoreMap))
|
||||
for word, score := range scoreMap {
|
||||
res = append(res, Keyword{Word: word, Score: score})
|
||||
}
|
||||
|
||||
sort.Slice(res, func(i, j int) bool {
|
||||
return res[i].Score > res[j].Score
|
||||
})
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
// ExtractTFIDF TF-IDF 关键词(带权重)90% 业务:文章标签、搜索、关键词
|
||||
func (k *gseTool) extractTFIDF(text string, topN int) segment.Segments {
|
||||
return k.tfidf.ExtractTags(text, topN)
|
||||
}
|
||||
|
||||
// ExtractTextRank TextRank 关键词(带权重)长文本、摘要、语义理解
|
||||
func (k *gseTool) extractTextRank(text string, topN int) segment.Segments {
|
||||
return k.tr.TextRank(text, topN)
|
||||
}
|
||||
69
common/task/base_task.go
Normal file
69
common/task/base_task.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package task
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
)
|
||||
|
||||
type baseTaskCol struct {
|
||||
beans.SQLBaseCol
|
||||
TaskType string
|
||||
Status string
|
||||
Priority string
|
||||
ParentTaskID string
|
||||
TotalItems string
|
||||
ProcessedItems string
|
||||
Progress string
|
||||
StartTime string
|
||||
EndTime string
|
||||
Duration string
|
||||
SuccessCount string
|
||||
FailCount string
|
||||
Executor string
|
||||
DocumentID string
|
||||
Remark string
|
||||
}
|
||||
|
||||
var BaseTaskCol = baseTaskCol{
|
||||
SQLBaseCol: beans.DefSQLBaseCol,
|
||||
TaskType: "task_type",
|
||||
Status: "status",
|
||||
Priority: "task_priority",
|
||||
ParentTaskID: "parent_task_id",
|
||||
TotalItems: "total_items",
|
||||
ProcessedItems: "processed_items",
|
||||
Progress: "progress",
|
||||
StartTime: "start_time",
|
||||
EndTime: "end_time",
|
||||
Duration: "duration",
|
||||
SuccessCount: "success_count",
|
||||
FailCount: "fail_count",
|
||||
Executor: "executor",
|
||||
DocumentID: "document_id",
|
||||
Remark: "remark",
|
||||
}
|
||||
|
||||
// SQLBaseTask 任务基类 - SQL版本
|
||||
type SQLBaseTask struct {
|
||||
beans.SQLBaseDO `orm:",inline"`
|
||||
// 任务核心信息
|
||||
TaskType TaskType `orm:"task_type" json:"taskType" dc:"任务类型"`
|
||||
Status TaskStatus `orm:"status" json:"status" dc:"任务状态"`
|
||||
Priority TaskPriority `orm:"task_priority" json:"priority,omitempty" dc:"任务优先级"`
|
||||
ParentTaskID int64 `orm:"parent_task_id" json:"parentTaskId,omitempty" dc:"父任务ID"`
|
||||
// 任务进度
|
||||
TotalItems int64 `orm:"total_items" json:"totalItems" dc:"总数"`
|
||||
ProcessedItems int64 `orm:"processed_items" json:"processedItems" dc:"已处理数"`
|
||||
Progress float64 `orm:"progress" json:"progress" dc:"进度"` // 0~100 百分比
|
||||
// 任务结果
|
||||
StartTime *time.Time `orm:"start_time" json:"startTime" dc:"开始时间"`
|
||||
EndTime *time.Time `orm:"end_time" json:"endTime,omitempty" dc:"结束时间"`
|
||||
Duration int64 `orm:"duration" json:"duration,omitempty" dc:"耗时(毫秒)"`
|
||||
SuccessCount int64 `orm:"success_count" json:"successCount" dc:"成功数"`
|
||||
FailCount int64 `orm:"fail_count" json:"failCount" dc:"失败数"`
|
||||
// 其他
|
||||
Executor string `orm:"executor" json:"executor,omitempty" dc:"执行器标识"`
|
||||
DocumentID int64 `orm:"document_id" json:"documentId,omitempty" dc:"文档ID"`
|
||||
Remark string `orm:"remark" json:"remark,omitempty" dc:"备注/错误信息"`
|
||||
}
|
||||
30
common/task/consts.go
Normal file
30
common/task/consts.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package task
|
||||
|
||||
// TaskType 任务类型枚举:文档解析的三个子任务
|
||||
type TaskType string
|
||||
|
||||
const (
|
||||
TaskTypeExtractKeywords TaskType = "EXTRACT_KEYWORDS" // 提取关键词
|
||||
TaskTypeGenerateVector TaskType = "GENERATE_VECTOR" // 生成向量
|
||||
TaskTypeFullTextSearch TaskType = "FULL_TEXT_SEARCH" // 全文检索
|
||||
TaskTypeDocParse TaskType = "DOC_PARSE" // 顶层文档解析总任务
|
||||
)
|
||||
|
||||
// TaskStatus 任务状态枚举
|
||||
type TaskStatus string
|
||||
|
||||
const (
|
||||
TaskStatusPending TaskStatus = "PENDING" // 待执行
|
||||
TaskStatusRunning TaskStatus = "RUNNING" // 执行中
|
||||
TaskStatusCompleted TaskStatus = "COMPLETED" // 已完成
|
||||
TaskStatusFailed TaskStatus = "FAILED" // 执行失败
|
||||
)
|
||||
|
||||
// TaskPriority 任务优先级
|
||||
type TaskPriority int
|
||||
|
||||
const (
|
||||
TaskPriorityLow TaskPriority = 1 // 低
|
||||
TaskPriorityMedium TaskPriority = 2 // 中
|
||||
TaskPriorityHigh TaskPriority = 3 // 高
|
||||
)
|
||||
24
config.yml
24
config.yml
@@ -48,10 +48,10 @@ database:
|
||||
timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性,为true时CreatedAt/UpdatedAt/DeletedAt都将失效
|
||||
rag_knowledge:
|
||||
- type: "pgsql"
|
||||
host: "localhost"
|
||||
port: "5432"
|
||||
host: "116.204.74.41"
|
||||
port: "15432"
|
||||
user: "postgres"
|
||||
pass: "123456"
|
||||
pass: "Bjang09@686^*^"
|
||||
name: "tenant-1"
|
||||
prefix: "rag_knowledge_" # (可选)表名前缀
|
||||
role: "master"
|
||||
@@ -69,10 +69,10 @@ database:
|
||||
timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性,为true时CreatedAt/UpdatedAt/DeletedAt都将失效
|
||||
rag_vector:
|
||||
- type: "pgsql"
|
||||
host: "localhost"
|
||||
port: "5432"
|
||||
host: "116.204.74.41"
|
||||
port: "15432"
|
||||
user: "postgres"
|
||||
pass: "123456"
|
||||
pass: "Bjang09@686^*^"
|
||||
name: "tenant-1"
|
||||
prefix: "rag_vector_" # (可选)表名前缀
|
||||
role: "master"
|
||||
@@ -91,14 +91,14 @@ database:
|
||||
|
||||
redis:
|
||||
default:
|
||||
address: "localhost:6379"
|
||||
address: "116.204.74.41:6379"
|
||||
db: 0
|
||||
|
||||
consul:
|
||||
address: localhost:8500
|
||||
address: 116.204.74.41:8500
|
||||
|
||||
jaeger:
|
||||
addr: localhost:4318
|
||||
addr: 116.204.74.41:4318
|
||||
|
||||
# eino框架配置
|
||||
eino:
|
||||
@@ -115,6 +115,10 @@ eino:
|
||||
# apiType: "multi_modal_api"
|
||||
apiKey: "sk-4a8b82770bf74bc490eb3e4c5a8e2be9"
|
||||
model: "text-embedding-v3"
|
||||
chatmodel:
|
||||
provider: "dashscope"
|
||||
apiKey: "sk-4a8b82770bf74bc490eb3e4c5a8e2be9"
|
||||
model: "qwen-turbo"
|
||||
|
||||
# 文件上传服务地址,与oss模块minio中的endpoint一致
|
||||
filePrefix: "http://116.204.74.41:9000"
|
||||
@@ -122,7 +126,7 @@ filePrefix: "http://116.204.74.41:9000"
|
||||
gmq:
|
||||
redis:
|
||||
primary:
|
||||
addr: "localhost"
|
||||
addr: "116.204.74.41"
|
||||
port: "6379"
|
||||
db: 0
|
||||
username: ""
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package public
|
||||
|
||||
// 数据库名称
|
||||
const (
|
||||
DbNameKnowledge = "rag_knowledge"
|
||||
DbNameVector = "rag_vector"
|
||||
@@ -10,6 +11,7 @@ const (
|
||||
TableNameDocument = "document"
|
||||
TableNameDataset = "dataset"
|
||||
TableNameKeyword = "keyword"
|
||||
TableNameTask = "task"
|
||||
TableNameDatasetIndex = "dataset_index"
|
||||
TableNameDocumentChunk = "document_chunk"
|
||||
)
|
||||
|
||||
@@ -48,7 +48,7 @@ func (c *document) List(ctx context.Context, req *dto.ListDocumentReq) (res *dto
|
||||
}
|
||||
|
||||
// Process 处理文件(向量化)
|
||||
func (c *document) Process(ctx context.Context, req *dto.ProcessDocumentReq) (res *dto.ProcessDocumentRes, err error) {
|
||||
res, err = service.Document.Process(ctx, req)
|
||||
func (c *document) Process(ctx context.Context, req *dto.ProcessDocumentReq) (res *beans.ResponseEmpty, err error) {
|
||||
err = service.Document.Process(ctx, req)
|
||||
return
|
||||
}
|
||||
|
||||
17
controller/rag_query.go
Normal file
17
controller/rag_query.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"rag/model/dto"
|
||||
"rag/service"
|
||||
)
|
||||
|
||||
type ragQuery struct{}
|
||||
|
||||
var RAGQuery = new(ragQuery)
|
||||
|
||||
// Query 执行RAG查询
|
||||
func (c *ragQuery) Query(ctx context.Context, req *dto.RAGQueryReq) (res *dto.RAGQueryRes, err error) {
|
||||
res, err = service.RAGQuery.Query(ctx, req)
|
||||
return
|
||||
}
|
||||
@@ -49,6 +49,7 @@ func (d *datasetIndexDao) InsertIndex(ctx context.Context, indexName string) (er
|
||||
CREATE INDEX IF NOT EXISTS %s
|
||||
ON %s
|
||||
USING ivfflat (vector vector_cosine_ops)
|
||||
WITH (lists = 100)
|
||||
WHERE vector IS NOT NULL;
|
||||
`, indexName, gfdb.TablePrefix+public.TableNameDocumentChunk)
|
||||
_, err = db.Exec(ctx, sqlStr)
|
||||
|
||||
@@ -2,12 +2,17 @@ package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"rag/consts/public"
|
||||
"rag/model/dto"
|
||||
"rag/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"gitea.com/red-future/common/full-text-search/meilisearch"
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/text/gstr"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
"github.com/pgvector/pgvector-go"
|
||||
)
|
||||
|
||||
var DocumentChunk = new(documentChunkDao)
|
||||
@@ -55,3 +60,56 @@ func (d *documentChunkDao) List(ctx context.Context, req *dto.ListDocumentChunkR
|
||||
err = r.Structs(&res)
|
||||
return
|
||||
}
|
||||
|
||||
func (d *documentChunkDao) GetAllByVector(ctx context.Context, datasetId []int64, queryVec pgvector.Vector, topK int) (list gdb.List, err error) {
|
||||
sql := `
|
||||
SELECT id, content, dataset_id, document_id,
|
||||
vector <-> ? AS distance
|
||||
FROM rag_vector_document_chunk
|
||||
WHERE dataset_id IN (?)
|
||||
AND vector IS NOT NULL
|
||||
ORDER BY distance ASC
|
||||
LIMIT ?
|
||||
`
|
||||
// 顺序:vector, dataset_id, topK
|
||||
result, err := gfdb.DB(ctx, public.DbNameVector).GetAll(ctx, sql, queryVec, datasetId, topK)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result.List(), nil
|
||||
}
|
||||
|
||||
// SearchByKeywords 通过关键词全文检索文档块
|
||||
func (d *documentChunkDao) SearchByKeywords(ctx context.Context, query string, datasetIds []int64, topK int) (list gdb.List, err error) {
|
||||
// 构建 meilisearch 查询参数
|
||||
searchParams := &meilisearch.SearchParams{
|
||||
Query: query,
|
||||
Limit: int64(topK),
|
||||
}
|
||||
|
||||
// 构建 datasetIds 过滤条件
|
||||
if len(datasetIds) > 0 {
|
||||
datasetIdStrs := gconv.Strings(datasetIds)
|
||||
quotedIds := make([]string, len(datasetIdStrs))
|
||||
for i, id := range datasetIdStrs {
|
||||
quotedIds[i] = fmt.Sprintf("%s", id)
|
||||
}
|
||||
searchParams.Filter = fmt.Sprintf("dataset_id IN [%s]", gstr.Implode(", ", quotedIds))
|
||||
}
|
||||
|
||||
// 执行搜索
|
||||
var hits []map[string]interface{}
|
||||
_, err = meilisearch.DB().Search(ctx, searchParams, public.IndexNameDocumentChunk, &hits)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换查询结果为 gdb.List
|
||||
resultList := make(gdb.List, 0, len(hits))
|
||||
for _, hit := range hits {
|
||||
resultList = append(resultList, hit)
|
||||
}
|
||||
|
||||
return resultList, nil
|
||||
}
|
||||
|
||||
@@ -82,6 +82,9 @@ func (d *keywordDao) List(ctx context.Context, req *dto.ListKeywordReq, fields .
|
||||
if !g.IsEmpty(req.Keyword) {
|
||||
model.WhereLike(entity.KeywordCol.Word, "%"+req.Keyword+"%")
|
||||
}
|
||||
model.WhereIn(entity.KeywordCol.Word, req.Words)
|
||||
model.Where(entity.KeywordCol.DatasetId, req.DatasetId)
|
||||
model.Where(entity.KeywordCol.DocumentId, req.DocumentId)
|
||||
model.OrderDesc(entity.KeywordCol.Weight)
|
||||
model.OrderDesc(entity.KeywordCol.CreatedAt)
|
||||
if req.Page != nil {
|
||||
|
||||
58
dao/task.go
Normal file
58
dao/task.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"rag/consts/public"
|
||||
"rag/model/dto"
|
||||
"rag/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
var Task = new(taskDao)
|
||||
|
||||
type taskDao struct{}
|
||||
|
||||
// Insert 创建任务
|
||||
func (d *taskDao) Insert(ctx context.Context, req *dto.CreateTaskReq) (id int64, err error) {
|
||||
var res *entity.Task
|
||||
if err = gconv.Struct(req, &res); err != nil {
|
||||
return
|
||||
}
|
||||
r, err := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameTask).Data(&res).Insert()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return r.LastInsertId()
|
||||
}
|
||||
|
||||
// Update 更新任务
|
||||
func (d *taskDao) Update(ctx context.Context, req *dto.UpdateTaskReq) (rows int64, err error) {
|
||||
model := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameTask)
|
||||
r, err := model.Data(&req).Where(entity.TaskCol.Id, req.Id).Where(entity.TaskCol.TaskId, req.TaskId).OmitEmpty().Update()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return r.RowsAffected()
|
||||
}
|
||||
|
||||
func (d *taskDao) Get(ctx context.Context, req *dto.GetTaskReq) (res []*entity.Task, total int, err error) {
|
||||
r, total, err := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameTask).OmitEmpty().
|
||||
Where(entity.TaskCol.Id, req.Id).
|
||||
Where(entity.TaskCol.TaskId, req.TaskId).
|
||||
Where(entity.TaskCol.TaskType, req.TaskType).AllAndCount(false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = r.Structs(&res)
|
||||
return
|
||||
}
|
||||
|
||||
func (d *taskDao) DeleteByTaskId(ctx context.Context, req *dto.DeleteTaskByTaskIdReq) (rows int64, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameTask).Where(entity.TaskCol.TaskId, req.TaskId).Delete()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return r.RowsAffected()
|
||||
}
|
||||
9
go.mod
9
go.mod
@@ -16,9 +16,9 @@ require (
|
||||
github.com/cloudwego/eino-ext/components/embedding/dashscope v0.0.0-20260323112355-f061db7e8419
|
||||
github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260323112355-f061db7e8419
|
||||
github.com/cloudwego/eino-ext/components/indexer/es8 v0.0.0-20260331071634-4f359694d2d9
|
||||
github.com/cloudwego/eino-ext/components/model/qwen v0.1.7
|
||||
github.com/cloudwego/eino-ext/components/retriever/es8 v0.0.0-20260331071634-4f359694d2d9
|
||||
github.com/elastic/go-elasticsearch/v8 v8.16.0
|
||||
github.com/go-ego/gse v1.0.2
|
||||
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.0
|
||||
github.com/gogf/gf/v2 v2.10.0
|
||||
github.com/golang/glog v1.2.5
|
||||
@@ -50,7 +50,7 @@ require (
|
||||
github.com/clipperhouse/uax29/v2 v2.7.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||
github.com/cloudwego/eino-ext/components/document/parser/html v0.0.0-20241224063832-9fbcc0e56c28 // indirect
|
||||
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.14 // indirect
|
||||
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.15 // indirect
|
||||
github.com/dgraph-io/badger/v4 v4.2.0 // indirect
|
||||
github.com/dgraph-io/ristretto v0.1.1 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
@@ -64,6 +64,7 @@ require (
|
||||
github.com/fatih/color v1.19.0 // indirect
|
||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.13 // indirect
|
||||
github.com/go-ego/gse v1.0.2 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
@@ -105,7 +106,7 @@ require (
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.21 // indirect
|
||||
github.com/meguminnnnnnnnn/go-openai v0.1.1 // indirect
|
||||
github.com/meguminnnnnnnnn/go-openai v0.1.2 // indirect
|
||||
github.com/meilisearch/meilisearch-go v0.36.1 // indirect
|
||||
github.com/microcosm-cc/bluemonday v1.0.27 // indirect
|
||||
github.com/mitchellh/go-homedir v1.1.0 // indirect
|
||||
@@ -134,7 +135,7 @@ require (
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/vcaesar/cedar v0.30.0 // indirect
|
||||
github.com/volcengine/volc-sdk-golang v1.0.199 // indirect
|
||||
github.com/volcengine/volcengine-go-sdk v1.0.181 // indirect
|
||||
github.com/volcengine/volcengine-go-sdk v1.2.9 // indirect
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
||||
github.com/xuri/efp v0.0.0-20240408161823-9ad904a10d6d // indirect
|
||||
github.com/xuri/excelize/v2 v2.9.0 // indirect
|
||||
|
||||
17
go.sum
17
go.sum
@@ -33,8 +33,6 @@ cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9
|
||||
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
|
||||
entgo.io/ent v0.14.3 h1:wokAV/kIlH9TeklJWGGS7AYJdVckr0DloWjIcO9iIIQ=
|
||||
entgo.io/ent v0.14.3/go.mod h1:aDPE/OziPEu8+OWbzy4UlvWmD2/kbRuWfK2A40hcxJM=
|
||||
gitea.com/red-future/common v0.0.11 h1:AV7W3G0uZ8aPpHHSHd4ZHmLWe5+2STPKe/AYPoPCWVc=
|
||||
gitea.com/red-future/common v0.0.11/go.mod h1:B8syUI4XbLCDQSeRHURYxEwnWw8mEFgmqCxjC+lM+NU=
|
||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||
github.com/BurntSushi/toml v1.1.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
|
||||
github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk=
|
||||
@@ -158,10 +156,12 @@ github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260323112355-
|
||||
github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260323112355-f061db7e8419/go.mod h1:SajSFFRIXJXIbxadAAlSUIS5KTY8R/jzJg9RNSOXCCI=
|
||||
github.com/cloudwego/eino-ext/components/indexer/es8 v0.0.0-20260331071634-4f359694d2d9 h1:vZ3dL8xwo2sy73aBVKs4AJiO5OCHRxMOJUwIYkp0CWs=
|
||||
github.com/cloudwego/eino-ext/components/indexer/es8 v0.0.0-20260331071634-4f359694d2d9/go.mod h1:+oI0sr0rA0OHCxaQJ0rzMYld3LAODHhPKzBx5JYCya0=
|
||||
github.com/cloudwego/eino-ext/components/model/qwen v0.1.7 h1:8c1LB5lH+dERbf2twp18B1Y822JOQSsS6x7Vnksehk0=
|
||||
github.com/cloudwego/eino-ext/components/model/qwen v0.1.7/go.mod h1:n4iuIUQeL3D8GRsGAhkgceRZpoyPQbqOXFMXM2Q4hNY=
|
||||
github.com/cloudwego/eino-ext/components/retriever/es8 v0.0.0-20260331071634-4f359694d2d9 h1:Sl6giB1SJlA+ZlO0gzPH05IsUORtdYYPN6GiyH1B9MA=
|
||||
github.com/cloudwego/eino-ext/components/retriever/es8 v0.0.0-20260331071634-4f359694d2d9/go.mod h1:H4kNmiTe2irnvipVNIP4q8yqXf2fZ6v24krvQYBtYb8=
|
||||
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.14 h1:yOZII6VYaL00CVZYba+HUixFygsW0Xz/1QjQ5htj1Ls=
|
||||
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.14/go.mod h1:1xMQZ8eE11pkEoTAEy8UlaAY817qGVMvjpDPGSIO3Ns=
|
||||
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.15 h1:LbdSG9+qWzzp9RFW6dSFkaUW171JvCoYn/K63zX6dQE=
|
||||
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.15/go.mod h1:p+l0zBB0GjjX8HTlbTs3g3KfUFwZC11bsCGZOXW/3L0=
|
||||
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
|
||||
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
|
||||
github.com/cncf/xds/go v0.0.0-20210312221358-fbca930ec8ed/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
|
||||
@@ -531,8 +531,8 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
|
||||
github.com/mattn/go-runewidth v0.0.21 h1:jJKAZiQH+2mIinzCJIaIG9Be1+0NR+5sz/lYEEjdM8w=
|
||||
github.com/mattn/go-runewidth v0.0.21/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
||||
github.com/meguminnnnnnnnn/go-openai v0.1.1 h1:u/IMMgrj/d617Dh/8BKAwlcstD74ynOJzCtVl+y8xAs=
|
||||
github.com/meguminnnnnnnnn/go-openai v0.1.1/go.mod h1:qs96ysDmxhE4BZoU45I43zcyfnaYxU3X+aRzLko/htY=
|
||||
github.com/meguminnnnnnnnn/go-openai v0.1.2 h1:iXombGGjqjBrmE9WaSidUhhi3YQhf42QTHvHLMkgvCA=
|
||||
github.com/meguminnnnnnnnn/go-openai v0.1.2/go.mod h1:qs96ysDmxhE4BZoU45I43zcyfnaYxU3X+aRzLko/htY=
|
||||
github.com/meilisearch/meilisearch-go v0.36.1 h1:mJTCJE5g7tRvaqKco6DfqOuJEjX+rRltDEnkEC02Y0M=
|
||||
github.com/meilisearch/meilisearch-go v0.36.1/go.mod h1:hWcR0MuWLSzHfbz9GGzIr3s9rnXLm1jqkmHkJPbUSvM=
|
||||
github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4=
|
||||
@@ -735,8 +735,8 @@ github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV
|
||||
github.com/volcengine/volc-sdk-golang v1.0.23/go.mod h1:AfG/PZRUkHJ9inETvbjNifTDgut25Wbkm2QoYBTbvyU=
|
||||
github.com/volcengine/volc-sdk-golang v1.0.199 h1:zv9QOqTl/IsLwtfC37GlJtcz6vMAHi+pjq8ILWjLYUc=
|
||||
github.com/volcengine/volc-sdk-golang v1.0.199/go.mod h1:stZX+EPgv1vF4nZwOlEe8iGcriUPRBKX8zA19gXycOQ=
|
||||
github.com/volcengine/volcengine-go-sdk v1.0.181 h1:/3PB4M1N4fjMqiSKTJwX43EZ5Nn1HUOtQrSCk+22+wI=
|
||||
github.com/volcengine/volcengine-go-sdk v1.0.181/go.mod h1:gfEDc1s7SYaGoY+WH2dRrS3qiuDJMkwqyfXWCa7+7oA=
|
||||
github.com/volcengine/volcengine-go-sdk v1.2.9 h1:du2gnImtyWXKkQFnJW/GXCs+UBibGGOXIbP1Ams2pB8=
|
||||
github.com/volcengine/volcengine-go-sdk v1.2.9/go.mod h1:oxoVo+A17kvkwPkIeIHPVLjSw7EQAm+l/Vau1YGHN+A=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
||||
github.com/x-cray/logrus-prefixed-formatter v0.5.2 h1:00txxvfBM9muc0jiLIEAkAcIMJzfthRT6usrui8uGmg=
|
||||
@@ -1193,6 +1193,7 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0
|
||||
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
|
||||
google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
|
||||
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
|
||||
|
||||
20
main.go
20
main.go
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"gitea.com/red-future/common/http"
|
||||
"gitea.com/red-future/common/jaeger"
|
||||
"gitea.com/red-future/common/utils"
|
||||
gmq "github.com/bjang03/gmq/core/gmq"
|
||||
"github.com/bjang03/gmq/mq"
|
||||
"github.com/bjang03/gmq/types"
|
||||
@@ -27,22 +28,17 @@ func main() {
|
||||
controller.Dataset,
|
||||
controller.Document,
|
||||
controller.DocumentChunk,
|
||||
controller.Keyword,
|
||||
controller.RAGQuery,
|
||||
})
|
||||
|
||||
gmq.Init("config.yml")
|
||||
|
||||
if err := gmq.GetGmq("primary").GmqSubscribe(ctx, &mq.RedisSubMessage{
|
||||
SubMessage: types.SubMessage{
|
||||
Topic: public.KnowledgeDocumentVectorStatusTopic,
|
||||
ConsumerName: public.KnowledgeDocumentVectorStatusConsumer,
|
||||
AutoAck: public.KnowledgeDocumentVectorStatusAutoAck,
|
||||
FetchCount: public.KnowledgeDocumentVectorStatusBatchSize,
|
||||
HandleFunc: service.Document.DocsVectorStatusMsg,
|
||||
},
|
||||
}); err != nil {
|
||||
return
|
||||
err := utils.InitGseTool(ctx)
|
||||
if err != nil {
|
||||
g.Log().Error(ctx, "gse 分词工具初始化失败:", err)
|
||||
}
|
||||
|
||||
gmq.Init("config.yml")
|
||||
|
||||
if err := gmq.GetGmq("primary").GmqSubscribe(ctx, &mq.RedisSubMessage{
|
||||
SubMessage: types.SubMessage{
|
||||
Topic: public.KnowledgeDocumentChunkTopic,
|
||||
|
||||
@@ -84,12 +84,6 @@ type ProcessDocumentReq struct {
|
||||
DatasetId int64 `json:"datasetId" v:"required#数据集ID不能为空"`
|
||||
}
|
||||
|
||||
// ProcessDocumentRes 处理文件响应
|
||||
type ProcessDocumentRes struct {
|
||||
ChunkCount int64 `json:"chunkCount"`
|
||||
CostTime int64 `json:"costTime"`
|
||||
}
|
||||
|
||||
type ListDocumentChunkRPC struct {
|
||||
List []*DocumentChunkRPC `json:"list"`
|
||||
}
|
||||
|
||||
@@ -52,6 +52,7 @@ type ListKeywordReq struct {
|
||||
DatasetId int64 `json:"datasetId"`
|
||||
DocumentId int64 `json:"documentId"`
|
||||
Word string `json:"word"`
|
||||
Words []string `json:"words"`
|
||||
Keyword string `json:"keyword" dc:"关键词搜索"`
|
||||
}
|
||||
|
||||
@@ -62,9 +63,11 @@ type ListKeywordRes struct {
|
||||
}
|
||||
|
||||
type KeywordVO struct {
|
||||
Id int64 `json:"id,string" dc:"id"`
|
||||
Word string `json:"word" dc:"关键词名称"`
|
||||
Weight int16 `json:"weight" dc:"权重"`
|
||||
CreatedAt *gtime.Time `json:"createdAt" dc:"创建时间"`
|
||||
UpdatedAt *gtime.Time `json:"updatedAt" dc:"更新时间"`
|
||||
Id int64 `json:"id,string" dc:"id"`
|
||||
Word string `json:"word" dc:"关键词名称"`
|
||||
Weight int16 `json:"weight" dc:"权重"`
|
||||
DatasetId int64 `json:"datasetId,string" dc:"数据集ID"`
|
||||
DocumentId int64 `json:"documentId,string" dc:"文档ID"`
|
||||
CreatedAt *gtime.Time `json:"createdAt" dc:"创建时间"`
|
||||
UpdatedAt *gtime.Time `json:"updatedAt" dc:"更新时间"`
|
||||
}
|
||||
|
||||
21
model/dto/rag_query.go
Normal file
21
model/dto/rag_query.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
// RAGQueryReq RAG查询请求
|
||||
type RAGQueryReq struct {
|
||||
g.Meta `path:"/ragQuery" method:"post" tags:"RAG查询" summary:"执行RAG查询" dc:"执行RAG查询"`
|
||||
|
||||
Content string `json:"content" v:"required#查询内容不能为空" dc:"用户问题"`
|
||||
DatasetIds []int64 `json:"datasetIds" dc:"数据集ID"`
|
||||
TopK int `json:"topK" d:"5" dc:"检索topK,默认5"`
|
||||
}
|
||||
|
||||
// RAGQueryRes RAG查询响应
|
||||
type RAGQueryRes struct {
|
||||
Answer string `json:"answer" dc:"生成的答案"`
|
||||
DatasetId string `json:"datasetId" dc:"使用的数据集ID"`
|
||||
Sources []string `json:"sources" dc:"参考来源"`
|
||||
}
|
||||
65
model/dto/task.go
Normal file
65
model/dto/task.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"rag/common/task"
|
||||
)
|
||||
|
||||
// WriteTaskProgressReq 写入任务进度请求
|
||||
type WriteTaskProgressReq struct {
|
||||
TaskType task.TaskType `json:"taskType" dc:"任务类型"`
|
||||
Status task.TaskStatus `json:"status" dc:"任务状态"`
|
||||
TaskId int64 `json:"taskId" dc:"任务ID"`
|
||||
Remark string `json:"remark" dc:"备注"`
|
||||
}
|
||||
|
||||
// CreateTaskReq 创建任务请求
|
||||
type CreateTaskReq struct {
|
||||
TaskType task.TaskType `json:"taskType" dc:"任务类型"`
|
||||
Status task.TaskStatus `json:"status" dc:"任务状态"`
|
||||
TaskId int64 `json:"taskId" dc:"任务ID"`
|
||||
Remark string `json:"remark" dc:"备注"`
|
||||
}
|
||||
|
||||
// UpdateTaskReq 更新任务请求
|
||||
type UpdateTaskReq struct {
|
||||
Id int64 `json:"id" dc:"任务ID"`
|
||||
TaskId int64 `json:"taskId" dc:"任务ID"`
|
||||
Status task.TaskStatus `json:"status" dc:"任务状态"`
|
||||
Remark string `json:"remark" dc:"备注"`
|
||||
}
|
||||
|
||||
// DeleteTaskByTaskIdReq 删除任务请求
|
||||
type DeleteTaskByTaskIdReq struct {
|
||||
TaskId int64 `json:"taskId" v:"required#任务id不能为空"`
|
||||
}
|
||||
|
||||
// GetTaskReq 获取任务请求
|
||||
type GetTaskReq struct {
|
||||
Id int64 `json:"id" dc:"任务ID"`
|
||||
TaskId int64 `json:"taskId" dc:"任务ID"`
|
||||
TaskType task.TaskType `json:"taskType" dc:"任务类型"`
|
||||
}
|
||||
|
||||
// TaskVO 任务视图对象
|
||||
type TaskVO struct {
|
||||
Id int64 `json:"id" dc:"任务ID"`
|
||||
TaskType task.TaskType `json:"taskType" dc:"任务类型"`
|
||||
Status task.TaskStatus `json:"status" dc:"任务状态"`
|
||||
Priority task.TaskPriority `json:"priority" dc:"任务优先级"`
|
||||
ParentTaskID int64 `json:"parentTaskId" dc:"父任务ID"`
|
||||
TotalItems int64 `json:"totalItems" dc:"总项数"`
|
||||
ProcessedItems int64 `json:"processedItems" dc:"已处理项数"`
|
||||
Progress float64 `json:"progress" dc:"进度百分比"`
|
||||
StartTime *int64 `json:"startTime" dc:"开始时间戳"`
|
||||
EndTime *int64 `json:"endTime" dc:"结束时间戳"`
|
||||
Duration int64 `json:"duration" dc:"耗时(毫秒)"`
|
||||
SuccessCount int64 `json:"successCount" dc:"成功数"`
|
||||
FailCount int64 `json:"failCount" dc:"失败数"`
|
||||
Executor string `json:"executor" dc:"执行器"`
|
||||
DocumentID int64 `json:"documentId" dc:"文档ID"`
|
||||
Remark string `json:"remark" dc:"备注"`
|
||||
Creator string `json:"creator" dc:"创建人"`
|
||||
CreatedAt int64 `json:"createdAt" dc:"创建时间"`
|
||||
Updater string `json:"updater" dc:"更新人"`
|
||||
UpdatedAt int64 `json:"updatedAt" dc:"更新时间"`
|
||||
}
|
||||
66
model/entity/task.go
Normal file
66
model/entity/task.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package entity
|
||||
|
||||
import (
|
||||
"rag/common/task"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
)
|
||||
|
||||
type taskCol struct {
|
||||
beans.SQLBaseCol
|
||||
TaskId string
|
||||
TaskType string
|
||||
Status string
|
||||
Executor string
|
||||
Remark string
|
||||
//Priority string
|
||||
//ParentTaskId string
|
||||
//TotalItems string
|
||||
//ProcessedItems string
|
||||
//Progress string
|
||||
//StartTime string
|
||||
//EndTime string
|
||||
//Duration string
|
||||
//SuccessCount string
|
||||
//FailCount string
|
||||
}
|
||||
|
||||
var TaskCol = taskCol{
|
||||
SQLBaseCol: beans.DefSQLBaseCol,
|
||||
TaskId: "task_id",
|
||||
TaskType: "task_type",
|
||||
Status: "status",
|
||||
Executor: "executor",
|
||||
Remark: "remark",
|
||||
//Priority: "priority",
|
||||
//ParentTaskId: "parent_task_id",
|
||||
//TotalItems: "total_items",
|
||||
//ProcessedItems: "processed_items",
|
||||
//Progress: "progress",
|
||||
//StartTime: "start_time",
|
||||
//EndTime: "end_time",
|
||||
//Duration: "duration",
|
||||
//SuccessCount: "success_count",
|
||||
//FailCount: "fail_count",
|
||||
}
|
||||
|
||||
// Task 任务记录表
|
||||
type Task struct {
|
||||
beans.SQLBaseDO `orm:",inline"`
|
||||
|
||||
TaskId int64 `orm:"task_id" json:"taskId" dc:"任务ID"`
|
||||
TaskType task.TaskType `orm:"task_type" json:"taskType" dc:"任务类型"`
|
||||
Status task.TaskStatus `orm:"status" json:"status" dc:"任务状态"`
|
||||
Executor string `orm:"executor" json:"executor" dc:"执行器"`
|
||||
Remark string `orm:"remark" json:"remark" dc:"备注"`
|
||||
//Priority task.TaskPriority `orm:"priority" json:"priority" dc:"任务优先级"`
|
||||
//ParentTaskId int64 `orm:"parent_task_id" json:"parentTaskId" dc:"父任务ID"`
|
||||
//TotalItems int64 `orm:"total_items" json:"totalItems" dc:"总项数"`
|
||||
//ProcessedItems int64 `orm:"processed_items" json:"processedItems" dc:"已处理项数"`
|
||||
//SuccessCount int64 `orm:"success_count" json:"successCount" dc:"成功数"`
|
||||
//FailCount int64 `orm:"fail_count" json:"failCount" dc:"失败数"`
|
||||
//Progress float64 `orm:"progress" json:"progress" dc:"进度百分比"`
|
||||
//StartTime *gtime.Time `orm:"start_time" json:"startTime" dc:"开始时间戳"`
|
||||
//EndTime *gtime.Time `orm:"end_time" json:"endTime" dc:"结束时间戳"`
|
||||
//Duration int64 `orm:"duration" json:"duration" dc:"耗时(毫秒)"`
|
||||
}
|
||||
@@ -5,17 +5,14 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"rag/common/eino"
|
||||
"rag/common/gse"
|
||||
"rag/common/task"
|
||||
"rag/consts/document"
|
||||
"rag/consts/public"
|
||||
"rag/dao"
|
||||
"rag/model/dto"
|
||||
"rag/model/entity"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"gitea.com/red-future/common/full-text-search/meilisearch"
|
||||
"gitea.com/red-future/common/http"
|
||||
@@ -29,6 +26,7 @@ import (
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/database/gredis"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/os/grpool"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
@@ -54,7 +52,13 @@ func (s *documentService) Create(ctx context.Context, req *dto.CreateDocumentReq
|
||||
return
|
||||
}
|
||||
res = &dto.CreateDocumentRes{Id: id}
|
||||
|
||||
// 写入任务进度待处理 任务类型为文档解析
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: id,
|
||||
TaskType: task.TaskTypeDocParse,
|
||||
Status: task.TaskStatusPending,
|
||||
Remark: "文档上传成功待解析: " + req.Title,
|
||||
})
|
||||
return
|
||||
})
|
||||
|
||||
@@ -79,11 +83,20 @@ func (s *documentService) Delete(ctx context.Context, req *dto.DeleteDocumentReq
|
||||
DocumentCount: -1,
|
||||
DocumentSize: -docs.FileSize,
|
||||
}
|
||||
_, err = dao.Dataset.Update(ctx, datasetReq)
|
||||
if err != nil {
|
||||
if _, err = dao.Dataset.Update(ctx, datasetReq); err != nil {
|
||||
return
|
||||
}
|
||||
_, err = dao.Document.Delete(ctx, req)
|
||||
|
||||
if _, err = dao.Document.Delete(ctx, req); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err = dao.Task.DeleteByTaskId(ctx, &dto.DeleteTaskByTaskIdReq{
|
||||
TaskId: docs.Id,
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
})
|
||||
|
||||
@@ -107,118 +120,159 @@ func (s *documentService) List(ctx context.Context, req *dto.ListDocumentReq) (r
|
||||
Total: total,
|
||||
}
|
||||
err = gconv.Struct(list, &res.List)
|
||||
|
||||
//eino.TestIndexer()
|
||||
//eino.TestRetriever()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Process 处理文件(使用eino框架切分和向量化)
|
||||
func (s *documentService) Process(ctx context.Context, req *dto.ProcessDocumentReq) (res *dto.ProcessDocumentRes, err error) {
|
||||
startTime := time.Now()
|
||||
|
||||
func (s *documentService) Process(ctx context.Context, req *dto.ProcessDocumentReq) (err error) {
|
||||
// 1. 查询文件信息
|
||||
documentReq := dto.GetDocumentReq{Id: req.Id}
|
||||
doc, err := dao.Document.GetByID(ctx, &documentReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
if g.IsEmpty(doc) {
|
||||
return nil, errors.New("document not found")
|
||||
return errors.New("document not found")
|
||||
}
|
||||
|
||||
// 2. 使用eino框架进行文件切分(并发执行)
|
||||
var vectorDocsCount, chunks int64
|
||||
// 用 gopool 或者简单的错误等待,绝对不用裸 goroutine
|
||||
var err1, err2, err3 error
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(3)
|
||||
|
||||
// 任务1
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
vectorDocsCount, chunks, err1 = s.sqlSplitDocument(ctx, doc)
|
||||
}()
|
||||
|
||||
// 任务2
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err2 = s.esSplitDocument(ctx, doc)
|
||||
}()
|
||||
|
||||
// 任务3
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err3 = s.extractDocument(ctx, doc)
|
||||
}()
|
||||
|
||||
// 直接等待,不使用通道,避免泄漏
|
||||
wg.Wait()
|
||||
|
||||
// 2. 更新文档状态为处理中
|
||||
updateDocumentReq := new(dto.UpdateDocumentReq)
|
||||
updateDocumentReq.Id = req.Id
|
||||
|
||||
// 统一判断错误
|
||||
if err1 != nil || err2 != nil || err3 != nil {
|
||||
// 更新文档状态
|
||||
updateDocumentReq.VectorStatus = document.VectorStatusFailed.Code()
|
||||
if _, err = dao.Document.Update(ctx, updateDocumentReq); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err1 != nil {
|
||||
return nil, err1
|
||||
}
|
||||
if err2 != nil {
|
||||
return nil, err2
|
||||
}
|
||||
return nil, err3
|
||||
}
|
||||
|
||||
// 4. 更新文件状态为处理中和切分数量
|
||||
if vectorDocsCount > 0 {
|
||||
updateDocumentReq.VectorStatus = document.VectorStatusProcessing.Code()
|
||||
} else {
|
||||
updateDocumentReq.VectorStatus = document.VectorStatusCompleted.Code()
|
||||
}
|
||||
updateDocumentReq.ChunkCount = chunks
|
||||
updateDocumentReq.VectorStatus = document.VectorStatusProcessing.Code()
|
||||
if _, err = dao.Document.Update(ctx, updateDocumentReq); err != nil {
|
||||
// 写入任务进度失败 任务类型为文档解析
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: req.Id,
|
||||
TaskType: task.TaskTypeDocParse,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "更新文档状态失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
costTime := time.Since(startTime).Milliseconds()
|
||||
|
||||
return &dto.ProcessDocumentRes{
|
||||
ChunkCount: chunks,
|
||||
CostTime: costTime,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *documentService) extractDocument(ctx context.Context, doc *entity.Document) (err error) {
|
||||
// 1. 加载文件
|
||||
docs, err := s.loadDocument(ctx, doc)
|
||||
// 写入任务进度进行中 任务类型为文档解析
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: req.Id,
|
||||
TaskType: task.TaskTypeDocParse,
|
||||
Status: task.TaskStatusRunning,
|
||||
Remark: "文档解析开始",
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var words []gse.Keyword
|
||||
// ======================
|
||||
// 核心:grpool + g.Try 最佳实践
|
||||
// ======================
|
||||
taskCtx, cancel := context.WithCancel(ctx)
|
||||
|
||||
// 任务1: SQL 切分文档
|
||||
grpool.Add(taskCtx, func(ctx context.Context) {
|
||||
g.TryCatch(ctx, func(ctx context.Context) {
|
||||
if innerErr := s.sqlSplitDocument(ctx, doc); innerErr != nil {
|
||||
cancel()
|
||||
}
|
||||
}, func(ctx context.Context, err error) {
|
||||
cancel()
|
||||
})
|
||||
})
|
||||
|
||||
// 任务2: ES 切分文档
|
||||
grpool.Add(taskCtx, func(ctx context.Context) {
|
||||
g.TryCatch(ctx, func(ctx context.Context) {
|
||||
if innerErr := s.esSplitDocument(ctx, doc); innerErr != nil {
|
||||
cancel()
|
||||
}
|
||||
}, func(ctx context.Context, err error) {
|
||||
cancel()
|
||||
})
|
||||
})
|
||||
|
||||
// 任务3: 提取文档
|
||||
grpool.Add(taskCtx, func(ctx context.Context) {
|
||||
g.TryCatch(ctx, func(ctx context.Context) {
|
||||
if innerErr := s.extractDocument(ctx, doc); innerErr != nil {
|
||||
cancel()
|
||||
}
|
||||
}, func(ctx context.Context, err error) {
|
||||
cancel()
|
||||
})
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractDocument 关键词提取(支持取消)
|
||||
func (s *documentService) extractDocument(ctx context.Context, doc *entity.Document) (err error) {
|
||||
// ========== 取消检查 1:方法入口 ==========
|
||||
if ctx.Err() != nil {
|
||||
// 写入任务进度失败 任务类型为关键字存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeExtractKeywords,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||
})
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// 1. 加载文件
|
||||
docs, err := s.loadDocument(ctx, doc)
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为关键字存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeExtractKeywords,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "加载文件失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var words []utils.Keyword
|
||||
if len(docs[0].Content) < 500 {
|
||||
words = gse.GseTool.Extract(docs[0].Content, 4)
|
||||
words = utils.GseTool.Extract(docs[0].Content, 4)
|
||||
} else if len(docs[0].Content) < 2000 {
|
||||
words = gse.GseTool.Extract(docs[0].Content, 8)
|
||||
words = utils.GseTool.Extract(docs[0].Content, 8)
|
||||
} else if len(docs[0].Content) < 5000 {
|
||||
words = gse.GseTool.Extract(docs[0].Content, 13)
|
||||
words = utils.GseTool.Extract(docs[0].Content, 13)
|
||||
} else {
|
||||
var docsSplit []*schema.Document
|
||||
docsSplit, err = eino.RecursiveSplitDocument(ctx, docs)
|
||||
if err != nil {
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeExtractKeywords,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "递归分割文档失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
// ========== 取消检查 2:循环内部 ==========
|
||||
for _, t := range docsSplit {
|
||||
words = append(words, gse.GseTool.Extract(t.Content, 6)...)
|
||||
if ctx.Err() != nil {
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeExtractKeywords,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||
})
|
||||
return ctx.Err()
|
||||
}
|
||||
words = append(words, utils.GseTool.Extract(t.Content, 6)...)
|
||||
}
|
||||
}
|
||||
|
||||
// ========== 取消检查 3:批量操作前 ==========
|
||||
if ctx.Err() != nil {
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeExtractKeywords,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||
})
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
var keywordReqs = make([]*dto.CreateKeywordReq, 0)
|
||||
for _, word := range words {
|
||||
keywordReqs = append(keywordReqs, &dto.CreateKeywordReq{
|
||||
@@ -231,37 +285,111 @@ func (s *documentService) extractDocument(ctx context.Context, doc *entity.Docum
|
||||
if len(keywordReqs) > 0 {
|
||||
_, err = dao.Keyword.BatchSaveOrUpdate(ctx, keywordReqs)
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为关键字存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeExtractKeywords,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "关键字存储失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
// 写入任务进度已完成 任务类型为关键字存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeExtractKeywords,
|
||||
Status: task.TaskStatusCompleted,
|
||||
Remark: "关键字提取完成",
|
||||
})
|
||||
} else {
|
||||
// 写入任务进度已完成 任务类型为关键字存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeExtractKeywords,
|
||||
Status: task.TaskStatusCompleted,
|
||||
Remark: "没有提取到关键词,关键字提取完成",
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Document) (vectorDocsCount, docsSplitCount int64, err error) {
|
||||
// sqlSplitDocument SQL切分(支持取消)
|
||||
func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Document) (err error) {
|
||||
// ========== 取消检查 1:方法入口 ==========
|
||||
if ctx.Err() != nil {
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||
})
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// 1. 加载文件
|
||||
docs, err := s.loadDocument(ctx, doc)
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为sql存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "加载文件失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 语义切分文件
|
||||
docsSplit, err := eino.SemanticSplitDocument(ctx, docs)
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为sql存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "文档切分失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
docsSplitCount = gconv.Int64(len(docsSplit))
|
||||
|
||||
// 2. 获取历史数据
|
||||
err = s.getHistoryData(ctx, doc, public.KnowledgeLockSqlKey, public.KnowledgeContentHashSqlKey)
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为sql存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "获取历史数据失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 组装向量文档
|
||||
var docsChunk = make([]*schema.Document, 0)
|
||||
for i, t := range docsSplit {
|
||||
// ========== 取消检查 2:循环内部 ==========
|
||||
if ctx.Err() != nil {
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||
})
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
contentHash := gmd5.MustEncryptString(t.Content)
|
||||
// 检查是否重复
|
||||
var success bool
|
||||
success, err = s.checkRepeat(ctx, public.KnowledgeContentHashSqlKey, contentHash)
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为sql存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "检查重复数据失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if !success {
|
||||
@@ -277,6 +405,18 @@ func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Docu
|
||||
t.MetaData = metaData
|
||||
docsChunk = append(docsChunk, t)
|
||||
}
|
||||
|
||||
// ========== 取消检查 3:批量发送前 ==========
|
||||
if ctx.Err() != nil {
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||
})
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// 4. 发送消息到队列
|
||||
if len(docsChunk) > 0 {
|
||||
err = gmq.GetGmq("primary").GmqPublish(ctx, &mq.RedisPubMessage{
|
||||
@@ -285,41 +425,117 @@ func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Docu
|
||||
Data: docsChunk,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为sql存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "发送消息到队列失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
// 写入任务进度进行中 任务类型为sql存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusRunning,
|
||||
Remark: "向量生成任务已提交到队列",
|
||||
})
|
||||
} else {
|
||||
// 写入任务进度已完成 任务类型为sql存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusCompleted,
|
||||
Remark: "无需生成向量,任务完成",
|
||||
})
|
||||
}
|
||||
vectorDocsCount = gconv.Int64(len(docsChunk))
|
||||
return
|
||||
}
|
||||
|
||||
// esSplitDocument ES切分(支持取消)
|
||||
func (s *documentService) esSplitDocument(ctx context.Context, doc *entity.Document) (err error) {
|
||||
// ========== 取消检查 1:方法入口 ==========
|
||||
if ctx.Err() != nil {
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeFullTextSearch,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||
})
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// 1. 加载文件
|
||||
docs, err := s.loadDocument(ctx, doc)
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为es存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeFullTextSearch,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "加载文件失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 递归切分文件
|
||||
docsSplit, err := eino.RecursiveSplitDocument(ctx, docs)
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为es存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeFullTextSearch,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "文档切分失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 获取历史数据
|
||||
err = s.getHistoryData(ctx, doc, public.KnowledgeLockEsKey, public.KnowledgeContentHashEsKey)
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为es存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeFullTextSearch,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "获取历史数据失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 组装向量文档并同时构建meilisearch文档
|
||||
var meiliDocs = make([]interface{}, 0)
|
||||
for i, t := range docsSplit {
|
||||
// ========== 取消检查 2:循环内部 ==========
|
||||
if ctx.Err() != nil {
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeFullTextSearch,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||
})
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
contentHash := gmd5.MustEncryptString(t.Content)
|
||||
// 检查是否重复
|
||||
var success bool
|
||||
success, err = s.checkRepeat(ctx, public.KnowledgeContentHashEsKey, contentHash)
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为es存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeFullTextSearch,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "检查重复数据失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if !success {
|
||||
continue
|
||||
}
|
||||
// 构建Meilisearch文档
|
||||
meiliDocs = append(meiliDocs, map[string]interface{}{
|
||||
entity.DocumentChunkCol.Id: contentHash,
|
||||
entity.DocumentChunkCol.DatasetId: doc.DatasetId,
|
||||
@@ -329,12 +545,45 @@ func (s *documentService) esSplitDocument(ctx context.Context, doc *entity.Docum
|
||||
entity.DocumentChunkCol.ChunkIndex: i,
|
||||
})
|
||||
}
|
||||
|
||||
// ========== 取消检查 3:批量写入前 ==========
|
||||
if ctx.Err() != nil {
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeFullTextSearch,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||
})
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// 4. 写入到meilisearch数据库中
|
||||
if len(meiliDocs) > 0 {
|
||||
if _, err = meilisearch.DB().InsertMany(ctx, meiliDocs, public.IndexNameDocumentChunk); err != nil {
|
||||
g.Log().Errorf(ctx, "写入meilisearch失败: %v", err)
|
||||
// 写入任务进度失败 任务类型为meilisearch存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeFullTextSearch,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "写入meilisearch失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
// 写入任务进度已完成 任务类型为meilisearch存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeFullTextSearch,
|
||||
Status: task.TaskStatusCompleted,
|
||||
Remark: "全文检索数据写入完成",
|
||||
})
|
||||
} else {
|
||||
// 写入任务进度已完成 任务类型为meilisearch存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeFullTextSearch,
|
||||
Status: task.TaskStatusCompleted,
|
||||
Remark: "无需生成全文检索数据,任务完成",
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -467,20 +716,3 @@ func (s *documentService) checkRepeat(ctx context.Context, contentKey, contentHa
|
||||
success = val.Bool()
|
||||
return
|
||||
}
|
||||
|
||||
func (s *documentService) DocsVectorStatusMsg(ctx context.Context, msg any) (err error) {
|
||||
var req = new(dto.KnowledgeDocumentMsg)
|
||||
if err = gconv.Struct(msg, &req); err != nil {
|
||||
g.Log().Error(ctx, "DocsVectorStatusMsg err:", err)
|
||||
return
|
||||
}
|
||||
ctx = context.WithValue(ctx, "user", &beans.User{
|
||||
TenantId: req.TenantId,
|
||||
UserName: req.Creator,
|
||||
})
|
||||
_, err = dao.Document.Update(ctx, &dto.UpdateDocumentReq{
|
||||
Id: req.Id,
|
||||
VectorStatus: req.VectorStatus,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -3,15 +3,11 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"rag/common/eino"
|
||||
"rag/consts/document"
|
||||
"rag/consts/public"
|
||||
"rag/common/task"
|
||||
"rag/dao"
|
||||
"rag/model/dto"
|
||||
"rag/model/entity"
|
||||
|
||||
gmq "github.com/bjang03/gmq/core/gmq"
|
||||
"github.com/bjang03/gmq/mq"
|
||||
"github.com/bjang03/gmq/types"
|
||||
"github.com/cloudwego/eino/components/indexer"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
@@ -22,10 +18,6 @@ var DocumentChunk = new(documentChunkService)
|
||||
|
||||
type documentChunkService struct{}
|
||||
|
||||
const (
|
||||
DatasetIndexStatusReady = "ready"
|
||||
)
|
||||
|
||||
// Update 更新文件块
|
||||
func (s *documentChunkService) Update(ctx context.Context, req *dto.UpdateDocumentChunkReq) (err error) {
|
||||
_, err = dao.DocumentChunk.Update(ctx, req)
|
||||
@@ -60,32 +52,29 @@ func (s *documentChunkService) DocsChunkMsg(ctx context.Context, msg any) (err e
|
||||
idx := eino.NewPGVectorIndexer(&eino.PGVectorIndexerOptions{
|
||||
BatchSize: 10,
|
||||
})
|
||||
documentId := gconv.Int64(docs[0].MetaData[entity.DocumentChunkCol.DocumentId])
|
||||
rows, err := idx.Store(ctx, docs, indexer.WithEmbedding(eino.EmbedderDashscope))
|
||||
if err != nil || rows == 0 {
|
||||
g.Log().Error(ctx, "DocsChunkMsg rows: , err:", rows, err)
|
||||
// 写入任务进度失败 任务类型为sql存储
|
||||
remark := " 向量存储数量: " + gconv.String(rows)
|
||||
if err != nil {
|
||||
remark = "向量存储失败: " + err.Error()
|
||||
}
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: documentId,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: remark,
|
||||
})
|
||||
return
|
||||
}
|
||||
tenantId := gconv.Uint64(docs[0].MetaData[entity.DocumentChunkCol.TenantId])
|
||||
creator := gconv.String(docs[0].MetaData[entity.DocumentChunkCol.Creator])
|
||||
documentId := gconv.Int64(docs[0].MetaData[entity.DocumentChunkCol.DocumentId])
|
||||
err = s.publishKnowledgeDocumentMsg(ctx, tenantId, creator, documentId, document.VectorStatusCompleted.Code())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// publishKnowledgeDocumentMsg 发布消息
|
||||
func (s *documentChunkService) publishKnowledgeDocumentMsg(ctx context.Context, tenantId uint64, creator string, documentId int64, vectorStatus document.VectorStatus) (err error) {
|
||||
knowledgeDocumentMsg := dto.KnowledgeDocumentMsg{
|
||||
TenantId: tenantId,
|
||||
Creator: creator,
|
||||
Id: documentId,
|
||||
VectorStatus: vectorStatus,
|
||||
}
|
||||
err = gmq.GetGmq("primary").GmqPublish(ctx, &mq.RedisPubMessage{
|
||||
PubMessage: types.PubMessage{
|
||||
Topic: public.KnowledgeDocumentVectorStatusTopic,
|
||||
Data: knowledgeDocumentMsg,
|
||||
},
|
||||
// 写入任务进度成功 任务类型为sql存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: documentId,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusCompleted,
|
||||
Remark: "向量生成完成",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
52
service/rag_query.go
Normal file
52
service/rag_query.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"rag/common/eino"
|
||||
"rag/model/dto"
|
||||
|
||||
"github.com/cloudwego/eino/components/retriever"
|
||||
"github.com/gogf/gf/v2/os/glog"
|
||||
)
|
||||
|
||||
var RAGQuery = new(ragQueryService)
|
||||
|
||||
type ragQueryService struct{}
|
||||
|
||||
// Query 执行RAG查询
|
||||
func (s *ragQueryService) Query(ctx context.Context, req *dto.RAGQueryReq) (*dto.RAGQueryRes, error) {
|
||||
if req.TopK <= 0 {
|
||||
req.TopK = 5
|
||||
}
|
||||
|
||||
// 4. 使用向量检索器进行查询
|
||||
r, err := eino.NewPGVectorRetriever(&eino.PGVectorRetrieverConfig{
|
||||
Embedder: eino.EmbedderDashscope,
|
||||
DefaultTopK: req.TopK,
|
||||
})
|
||||
if err != nil {
|
||||
glog.Errorf(ctx, "初始化向量检索器失败: %v", err)
|
||||
return nil, fmt.Errorf("初始化向量检索器失败: %w", err)
|
||||
}
|
||||
|
||||
// 5. 执行向量检索
|
||||
docs, err := r.Retrieve(ctx, req.Content, retriever.WithEmbedding(eino.EmbedderDashscope), retriever.WithDSLInfo(map[string]any{
|
||||
"dataset_ids": req.DatasetIds,
|
||||
}))
|
||||
if err != nil {
|
||||
glog.Errorf(ctx, "向量检索失败: %v", err)
|
||||
return nil, fmt.Errorf("向量检索失败: %w", err)
|
||||
}
|
||||
|
||||
replyMsg, sources, err := eino.NewChatModel(ctx, req.Content, docs)
|
||||
if err != nil {
|
||||
glog.Errorf(ctx, "向量检索失败: %v", err)
|
||||
return nil, fmt.Errorf("向量检索失败: %w", err)
|
||||
}
|
||||
|
||||
return &dto.RAGQueryRes{
|
||||
Answer: replyMsg.Content,
|
||||
Sources: sources,
|
||||
}, nil
|
||||
}
|
||||
107
service/task.go
Normal file
107
service/task.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"rag/dao"
|
||||
"rag/model/dto"
|
||||
|
||||
"rag/common/task"
|
||||
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
var Task = new(taskService)
|
||||
|
||||
type taskService struct{}
|
||||
|
||||
// WriteTaskProgress 写入任务进度(核心方法)
|
||||
func (s *taskService) WriteTaskProgress(ctx context.Context, req *dto.WriteTaskProgressReq) (err error) {
|
||||
t, total, err := dao.Task.Get(ctx, &dto.GetTaskReq{
|
||||
TaskId: req.TaskId,
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "查询任务失败: %v", err)
|
||||
return err
|
||||
}
|
||||
taskVO := make([]dto.TaskVO, 0, total)
|
||||
err = gconv.Struct(t, taskVO)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "转换任务失败: %v", err)
|
||||
return err
|
||||
}
|
||||
taskVO = append(taskVO, dto.TaskVO{
|
||||
TaskType: req.TaskType,
|
||||
Status: req.Status,
|
||||
})
|
||||
completed := IsAllSubTasksCompleted(taskVO)
|
||||
|
||||
// 1. 查询是否已存在该文档的该类型任务
|
||||
existTask, _, err := dao.Task.Get(ctx, &dto.GetTaskReq{
|
||||
TaskId: req.TaskId,
|
||||
TaskType: req.TaskType,
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "查询任务失败: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// 2. 如果不存在,则创建新任务
|
||||
if g.IsEmpty(existTask) {
|
||||
createReq := &dto.CreateTaskReq{
|
||||
TaskId: req.TaskId,
|
||||
TaskType: req.TaskType,
|
||||
Status: req.Status,
|
||||
Remark: req.Remark,
|
||||
}
|
||||
_, err = dao.Task.Insert(ctx, createReq)
|
||||
} else {
|
||||
// 3. 如果已存在,则更新任务
|
||||
updateReq := &dto.UpdateTaskReq{
|
||||
Id: existTask[0].Id,
|
||||
Status: req.Status,
|
||||
Remark: req.Remark,
|
||||
}
|
||||
_, err = dao.Task.Update(ctx, updateReq)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "更新任务失败: %v", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if completed {
|
||||
// 3. 如果已存在,则更新任务
|
||||
_, err = dao.Task.Update(ctx, &dto.UpdateTaskReq{
|
||||
TaskId: req.TaskId,
|
||||
Status: task.TaskStatusCompleted,
|
||||
})
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// IsAllSubTasksCompleted 判断三个子任务是否全部完成
|
||||
// 参数:传入当前文档的所有子任务列表
|
||||
func IsAllSubTasksCompleted(subTasks []dto.TaskVO) bool {
|
||||
// 必须包含 3 种任务类型
|
||||
hasKeywords := false
|
||||
hasVector := false
|
||||
hasFullText := false
|
||||
|
||||
for _, t := range subTasks {
|
||||
// 子任务必须是【已完成】状态才计数
|
||||
if t.Status == task.TaskStatusCompleted {
|
||||
switch t.TaskType {
|
||||
case task.TaskTypeExtractKeywords:
|
||||
hasKeywords = true
|
||||
case task.TaskTypeGenerateVector:
|
||||
hasVector = true
|
||||
case task.TaskTypeFullTextSearch:
|
||||
hasFullText = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 三个任务全部完成 → 返回true
|
||||
return hasKeywords && hasVector && hasFullText
|
||||
}
|
||||
44
update.sql
44
update.sql
@@ -114,6 +114,7 @@ COMMENT ON COLUMN rag_knowledge_document.file_path IS '文件存储路径(如M
|
||||
COMMENT ON COLUMN rag_knowledge_document.metadata IS '文件元数据,结构:{"author":"作者","tags":["标签1","标签2"],"custom":{"key":"值"}}';
|
||||
|
||||
--------------------pgsql创建rag_knowledge_document表语句---------------------------
|
||||
|
||||
--------------------pgsql创建rag_knowledge_keyword表语句---------------------------
|
||||
-- 关键词表(文档关键词+权重)
|
||||
CREATE TABLE IF NOT EXISTS rag_knowledge_keyword (
|
||||
@@ -161,6 +162,49 @@ COMMENT ON COLUMN rag_knowledge_keyword.weight IS '权重';
|
||||
|
||||
--------------------pgsql创建rag_knowledge_keyword表语句---------------------------
|
||||
|
||||
--------------------pgsql创建rag_knowledge_task表语句---------------------------
|
||||
-- 知识库任务表
|
||||
CREATE TABLE IF NOT EXISTS rag_knowledge_task (
|
||||
-- 基础字段(完全对齐项目规范)
|
||||
id BIGINT PRIMARY KEY, -- 主键ID(非自增)
|
||||
tenant_id BIGINT NOT NULL DEFAULT 0, -- 租户ID int8
|
||||
creator VARCHAR(64) NOT NULL,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updater VARCHAR(64) NOT NULL,
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
deleted_at timestamp(6),
|
||||
|
||||
-- 业务字段
|
||||
task_id BIGINT NOT NULL, -- 任务ID
|
||||
task_type VARCHAR(32) NOT NULL, -- 任务类型
|
||||
status VARCHAR(32) NOT NULL, -- 任务状态
|
||||
executor VARCHAR(128) DEFAULT '', -- 执行器
|
||||
remark TEXT DEFAULT '' -- 备注
|
||||
);
|
||||
|
||||
-- 索引(高频查询)
|
||||
CREATE INDEX idx_rkt_tenant_id ON rag_knowledge_task(tenant_id);
|
||||
CREATE INDEX idx_rkt_task_id ON rag_knowledge_task(task_id);
|
||||
CREATE INDEX idx_rkt_task_type ON rag_knowledge_task(task_type);
|
||||
CREATE INDEX idx_rkt_status ON rag_knowledge_task(status);
|
||||
CREATE INDEX idx_rkt_deleted_at ON rag_knowledge_task(deleted_at);
|
||||
|
||||
-- 表和字段注释
|
||||
COMMENT ON TABLE rag_knowledge_task IS '知识库任务表';
|
||||
COMMENT ON COLUMN rag_knowledge_task.id IS '主键ID(非自增)';
|
||||
COMMENT ON COLUMN rag_knowledge_task.tenant_id IS '租户ID';
|
||||
COMMENT ON COLUMN rag_knowledge_task.creator IS '创建人';
|
||||
COMMENT ON COLUMN rag_knowledge_task.created_at IS '创建时间';
|
||||
COMMENT ON COLUMN rag_knowledge_task.updater IS '更新人';
|
||||
COMMENT ON COLUMN rag_knowledge_task.updated_at IS '更新时间';
|
||||
COMMENT ON COLUMN rag_knowledge_task.deleted_at IS '删除时间(软删)';
|
||||
COMMENT ON COLUMN rag_knowledge_task.task_id IS '任务ID';
|
||||
COMMENT ON COLUMN rag_knowledge_task.task_type IS '任务类型';
|
||||
COMMENT ON COLUMN rag_knowledge_task.status IS '任务状态';
|
||||
COMMENT ON COLUMN rag_knowledge_task.executor IS '执行器';
|
||||
COMMENT ON COLUMN rag_knowledge_task.remark IS '备注';
|
||||
|
||||
--------------------pgsql创建rag_knowledge_task表语句---------------------------
|
||||
|
||||
|
||||
--------------------pgsql创建rag_vector_dataset_index表语句---------------------------
|
||||
|
||||
Reference in New Issue
Block a user