refactor: 重构文档处理流程和任务管理
This commit is contained in:
166
common/eino/a.go
166
common/eino/a.go
@@ -1,166 +0,0 @@
|
|||||||
package eino
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"github.com/cloudwego/eino/components/prompt"
|
|
||||||
"github.com/cloudwego/eino/components/retriever"
|
|
||||||
"github.com/cloudwego/eino/schema"
|
|
||||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
// ==========================================
|
|
||||||
// 1. 初始化三大组件
|
|
||||||
// ==========================================
|
|
||||||
// 1.1 向量检索(从知识库查客服知识)
|
|
||||||
ragRetriever := NewPGVectorRetriever()
|
|
||||||
|
|
||||||
// 1.2 提示词模板(客服角色 + 历史 + 知识库 + 用户问题)
|
|
||||||
chatTpl := newCustomerServiceTemplate()
|
|
||||||
|
|
||||||
// 1.3 大模型(ARK)
|
|
||||||
chatModel, err := ark.NewChatModel(ctx, &ark.ChatModelConfig{
|
|
||||||
APIKey: os.Getenv("ARK_API_KEY"),
|
|
||||||
Model: os.Getenv("ARK_MODEL_ID"),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==========================================
|
|
||||||
// 2. 模拟会话:从 DB 读取历史对话
|
|
||||||
// ==========================================
|
|
||||||
sessionHistory := []*schema.Message{
|
|
||||||
{Role: schema.User, Content: "你们发什么快递?"},
|
|
||||||
{Role: schema.Assistant, Content: "默认发中通快递"},
|
|
||||||
{Role: schema.User, Content: "可以发顺丰吗?"},
|
|
||||||
}
|
|
||||||
|
|
||||||
// 当前用户问题
|
|
||||||
userQuery := "那顺丰需要加钱吗?"
|
|
||||||
|
|
||||||
// ==========================================
|
|
||||||
// 3. RAG 检索知识库
|
|
||||||
// ==========================================
|
|
||||||
docs, err := ragRetriever.Retrieve(ctx, userQuery)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 拼接参考知识
|
|
||||||
knowledge := ""
|
|
||||||
for i, doc := range docs {
|
|
||||||
knowledge += fmt.Sprintf("[参考%d] %s\n", i+1, doc.Content)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==========================================
|
|
||||||
// 4. 模板格式化:系统提示 + 历史 + 知识 + 当前问题
|
|
||||||
// ==========================================
|
|
||||||
msgs, err := chatTpl.Format(ctx, map[string]any{
|
|
||||||
"history": sessionHistory,
|
|
||||||
"knowledge": knowledge,
|
|
||||||
"question": userQuery,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==========================================
|
|
||||||
// 5. 流式调用大模型生成客服回答
|
|
||||||
// ==========================================
|
|
||||||
fmt.Println("\n=== 客服回复 ===")
|
|
||||||
stream, err := chatModel.Stream(ctx, msgs)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fullReply := make([]*schema.Message, 0, 100)
|
|
||||||
for {
|
|
||||||
chunk, err := stream.Recv()
|
|
||||||
if errors.Is(err, io.EOF) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
fmt.Print(chunk.Content)
|
|
||||||
fullReply = append(fullReply, chunk)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==========================================
|
|
||||||
// 6. 拼接完整回复,存入 DB 作为新历史
|
|
||||||
// ==========================================
|
|
||||||
replyMsg, _ := schema.ConcatMessages(fullReply)
|
|
||||||
sessionHistory = append(sessionHistory,
|
|
||||||
&schema.Message{Role: schema.User, Content: userQuery},
|
|
||||||
replyMsg,
|
|
||||||
)
|
|
||||||
|
|
||||||
// 接下来把 sessionHistory 存回你的 MySQL/Redis 即可
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==========================================
|
|
||||||
// 本地客服提示词模板(不需要 MCP)
|
|
||||||
// ==========================================
|
|
||||||
func newCustomerServiceTemplate() prompt.ChatTemplate {
|
|
||||||
// 系统提示 + 多轮对话 + 知识库 + 用户问题
|
|
||||||
return prompt.FromMessages(schema.Messages{
|
|
||||||
{
|
|
||||||
Role: schema.System,
|
|
||||||
Content: `你是电商智能客服,语气友好简洁。
|
|
||||||
请严格根据参考知识回答,不知道就说“抱歉,这个问题我需要帮你转接人工”。
|
|
||||||
|
|
||||||
参考知识:
|
|
||||||
{{.knowledge}}`,
|
|
||||||
},
|
|
||||||
// 历史对话会自动渲染在这里
|
|
||||||
{{range .history}}{{.}},{{end}},
|
|
||||||
// 当前用户问题
|
|
||||||
{Role: schema.User, Content: "{{.question}}"},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==========================================
|
|
||||||
// PGVector 检索器(简化可直接用)
|
|
||||||
// ==========================================
|
|
||||||
type PGVectorRetriever struct {
|
|
||||||
topK int
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewPGVectorRetriever() retriever.Retriever {
|
|
||||||
return &PGVectorRetriever{topK: 3}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *PGVectorRetriever) Retrieve(
|
|
||||||
ctx context.Context,
|
|
||||||
query string,
|
|
||||||
opts ...retriever.Option,
|
|
||||||
) ([]*schema.Document, error) {
|
|
||||||
|
|
||||||
options := retriever.GetCommonOptions(nil, opts...)
|
|
||||||
topK := r.topK
|
|
||||||
if options.TopK != nil {
|
|
||||||
topK = *options.TopK
|
|
||||||
}
|
|
||||||
|
|
||||||
// ===== 这里替换成你真实的 PG 向量检索 SQL =====
|
|
||||||
// 模拟知识库
|
|
||||||
return []*schema.Document{
|
|
||||||
{
|
|
||||||
ID: "1",
|
|
||||||
Content: "顺丰快递需要补10元运费差价",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "2",
|
|
||||||
Content: "订单满99元可免费升级顺丰",
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
107
common/eino/b.go
107
common/eino/b.go
@@ -1,107 +0,0 @@
|
|||||||
package eino
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/cloudwego/eino/schema"
|
|
||||||
"github.com/elastic/go-elasticsearch/v8"
|
|
||||||
|
|
||||||
"github.com/cloudwego/eino-ext/components/indexer/es8"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
indexName = "eino_example"
|
|
||||||
fieldContent = "content"
|
|
||||||
fieldContentVector = "content_vector"
|
|
||||||
fieldExtraLocation = "location"
|
|
||||||
docExtraLocation = "location"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestIndexer() {
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
// 1. 创建 ES 客户端
|
|
||||||
client, err := elasticsearch.NewClient(elasticsearch.Config{
|
|
||||||
Addresses: []string{"http://localhost:9200"},
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("create client error: %v\n", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. 定义 Index Spec(选填:如果索引不存在,将自动创建)
|
|
||||||
indexSpec := &es8.IndexSpec{
|
|
||||||
Settings: map[string]any{
|
|
||||||
"number_of_shards": 1,
|
|
||||||
"number_of_replicas": 0,
|
|
||||||
},
|
|
||||||
Mappings: map[string]any{
|
|
||||||
"properties": map[string]any{
|
|
||||||
fieldContentVector: map[string]any{
|
|
||||||
"type": "dense_vector",
|
|
||||||
"dims": 1024,
|
|
||||||
"index": true,
|
|
||||||
"similarity": "l2_norm",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// 4. 准备文档
|
|
||||||
// 文档通常包含 ID 和 Content
|
|
||||||
// 也可以包含额外的 Metadata 用于过滤或其他用途
|
|
||||||
docs := []*schema.Document{
|
|
||||||
{
|
|
||||||
ID: "1",
|
|
||||||
Content: "Eiffel Tower: Located in Paris, France.",
|
|
||||||
MetaData: map[string]any{
|
|
||||||
docExtraLocation: "France",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ID: "2",
|
|
||||||
Content: "The Great Wall: Located in China.",
|
|
||||||
MetaData: map[string]any{
|
|
||||||
docExtraLocation: "China",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// 5. 创建 ES 索引器组件
|
|
||||||
indexer, err := es8.NewIndexer(ctx, &es8.IndexerConfig{
|
|
||||||
Client: client,
|
|
||||||
Index: indexName,
|
|
||||||
IndexSpec: indexSpec, // 添加此项以启用自动索引创建
|
|
||||||
BatchSize: 10,
|
|
||||||
// DocumentToFields 指定如何将文档字段映射到 ES 字段
|
|
||||||
DocumentToFields: func(ctx context.Context, doc *schema.Document) (field2Value map[string]es8.FieldValue, err error) {
|
|
||||||
return map[string]es8.FieldValue{
|
|
||||||
fieldContent: {
|
|
||||||
Value: doc.Content,
|
|
||||||
EmbedKey: fieldContentVector, // 对文档内容进行向量化并保存到 "content_vector" 字段
|
|
||||||
},
|
|
||||||
fieldExtraLocation: {
|
|
||||||
// 额外的 metadata 字段
|
|
||||||
Value: doc.MetaData[docExtraLocation],
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
// 提供 embedding 组件用于向量化
|
|
||||||
Embedding: EmbedderDashscope,
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("create indexer error: %v\n", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 6. 索引文档
|
|
||||||
ids, err := indexer.Store(ctx, docs)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("index error: %v\n", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
fmt.Println("indexed ids:", ids)
|
|
||||||
}
|
|
||||||
@@ -1,49 +0,0 @@
|
|||||||
package eino
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"gitea.com/red-future/common/beans"
|
|
||||||
)
|
|
||||||
|
|
||||||
// BaseTask 任务基类 - MongoDB版本
|
|
||||||
type BaseTask struct {
|
|
||||||
beans.MongoBaseDO `bson:",inline"`
|
|
||||||
// 任务信息
|
|
||||||
TaskType TaskType `bson:"taskType" json:"taskType"`
|
|
||||||
Status TaskStatus `bson:"status" json:"status"`
|
|
||||||
Priority TaskPriority `bson:"priority,omitempty" json:"priority,omitempty"`
|
|
||||||
// 进度
|
|
||||||
TotalItems int64 `bson:"totalItems" json:"totalItems"`
|
|
||||||
ProcessedItems int64 `bson:"processedItems" json:"processedItems"`
|
|
||||||
Progress float64 `bson:"progress" json:"progress"`
|
|
||||||
// 结果
|
|
||||||
StartTime *time.Time `bson:"startTime" json:"startTime"`
|
|
||||||
EndTime *time.Time `bson:"endTime,omitempty" json:"endTime,omitempty"`
|
|
||||||
Duration int64 `bson:"duration,omitempty" json:"duration,omitempty"`
|
|
||||||
SuccessCount int64 `bson:"successCount" json:"successCount"`
|
|
||||||
FailCount int64 `bson:"failCount" json:"failCount"`
|
|
||||||
// 其他
|
|
||||||
Executor string `bson:"executor,omitempty" json:"executor,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// SQLBaseTask 任务基类 - SQL版本
|
|
||||||
type SQLBaseTask struct {
|
|
||||||
beans.SQLBaseDO
|
|
||||||
// 任务信息
|
|
||||||
TaskType TaskType `json:"taskType"`
|
|
||||||
Status TaskStatus `json:"status"`
|
|
||||||
Priority TaskPriority `json:"priority,omitempty"`
|
|
||||||
// 进度
|
|
||||||
TotalItems int64 `json:"totalItems"`
|
|
||||||
ProcessedItems int64 `json:"processedItems"`
|
|
||||||
Progress float64 `json:"progress"`
|
|
||||||
// 结果
|
|
||||||
StartTime *time.Time `json:"startTime"`
|
|
||||||
EndTime *time.Time `json:"endTime,omitempty"`
|
|
||||||
Duration int64 `json:"duration,omitempty"`
|
|
||||||
SuccessCount int64 `json:"successCount"`
|
|
||||||
FailCount int64 `json:"failCount"`
|
|
||||||
// 其他
|
|
||||||
Executor string `json:"executor,omitempty"`
|
|
||||||
}
|
|
||||||
@@ -1,94 +0,0 @@
|
|||||||
package eino
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/cloudwego/eino/schema"
|
|
||||||
"github.com/elastic/go-elasticsearch/v8"
|
|
||||||
"github.com/elastic/go-elasticsearch/v8/typedapi/types"
|
|
||||||
|
|
||||||
"github.com/cloudwego/eino-ext/components/retriever/es8"
|
|
||||||
"github.com/cloudwego/eino-ext/components/retriever/es8/search_mode"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestRetriever() {
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
client, _ := elasticsearch.NewClient(elasticsearch.Config{
|
|
||||||
Addresses: []string{"http://localhost:9200"},
|
|
||||||
})
|
|
||||||
|
|
||||||
// 创建 retriever 组件
|
|
||||||
retriever, _ := es8.NewRetriever(ctx, &es8.RetrieverConfig{
|
|
||||||
Client: client,
|
|
||||||
Index: indexName,
|
|
||||||
TopK: 5,
|
|
||||||
SearchMode: search_mode.SearchModeApproximate(&search_mode.ApproximateConfig{
|
|
||||||
QueryFieldName: fieldContent,
|
|
||||||
VectorFieldName: fieldContentVector,
|
|
||||||
Hybrid: false,
|
|
||||||
// RRF 仅在特定许可证下可用
|
|
||||||
// 参见: https://www.elastic.co/subscriptions
|
|
||||||
RRF: false,
|
|
||||||
RRFRankConstant: nil,
|
|
||||||
RRFWindowSize: nil,
|
|
||||||
}),
|
|
||||||
ResultParser: func(ctx context.Context, hit types.Hit) (doc *schema.Document, err error) {
|
|
||||||
doc = &schema.Document{
|
|
||||||
ID: *hit.Id_,
|
|
||||||
Content: "",
|
|
||||||
MetaData: map[string]any{},
|
|
||||||
}
|
|
||||||
|
|
||||||
var src map[string]any
|
|
||||||
if err = json.Unmarshal(hit.Source_, &src); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for field, val := range src {
|
|
||||||
switch field {
|
|
||||||
case fieldContent:
|
|
||||||
doc.Content = val.(string)
|
|
||||||
case fieldContentVector:
|
|
||||||
var v []float64
|
|
||||||
for _, item := range val.([]interface{}) {
|
|
||||||
v = append(v, item.(float64))
|
|
||||||
}
|
|
||||||
doc.WithDenseVector(v)
|
|
||||||
case fieldExtraLocation:
|
|
||||||
doc.MetaData[docExtraLocation] = val.(string)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if hit.Score_ != nil {
|
|
||||||
doc.WithScore(float64(*hit.Score_))
|
|
||||||
}
|
|
||||||
|
|
||||||
return doc, nil
|
|
||||||
},
|
|
||||||
Embedding: EmbedderDashscope,
|
|
||||||
})
|
|
||||||
|
|
||||||
// 不带过滤器的搜索
|
|
||||||
docs, _ := retriever.Retrieve(ctx, "tourist attraction")
|
|
||||||
|
|
||||||
// 带过滤器的搜索
|
|
||||||
docs, _ = retriever.Retrieve(ctx, "tourist attraction",
|
|
||||||
es8.WithFilters([]types.Query{{
|
|
||||||
Term: map[string]types.TermQuery{
|
|
||||||
fieldExtraLocation: {
|
|
||||||
CaseInsensitive: of(true),
|
|
||||||
Value: "China",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}}),
|
|
||||||
)
|
|
||||||
|
|
||||||
fmt.Printf("retrieved docs: %+v\n", docs)
|
|
||||||
}
|
|
||||||
|
|
||||||
func of[T any](v T) *T {
|
|
||||||
return &v
|
|
||||||
}
|
|
||||||
125
common/eino/chat_model.go
Normal file
125
common/eino/chat_model.go
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
package eino
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino-ext/components/model/qwen"
|
||||||
|
"github.com/cloudwego/eino/components/prompt"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
|
"github.com/gogf/gf/v2/os/glog"
|
||||||
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
var globalChatModel *qwen.ChatModel
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
apiKey := g.Cfg().MustGet(ctx, "eino.chatmodel.apiKey").String()
|
||||||
|
model := g.Cfg().MustGet(ctx, "eino.chatmodel.model").String()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
globalChatModel, err = qwen.NewChatModel(ctx, &qwen.ChatModelConfig{
|
||||||
|
APIKey: apiKey,
|
||||||
|
Model: model,
|
||||||
|
BaseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||||
|
Temperature: gconv.PtrFloat32(0.7), // 客服最佳
|
||||||
|
MaxTokens: gconv.PtrInt(1024), // 最长回答
|
||||||
|
TopP: gconv.PtrFloat32(1.0),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
glog.Errorf(ctx, "初始化大模型失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewChatModel 只处理逻辑,不复用创建模型
|
||||||
|
func NewChatModel(ctx context.Context, content string, docs []*schema.Document) (replyMsg *schema.Message, sources []string, err error) {
|
||||||
|
// 1. 构建参考知识
|
||||||
|
knowledge, sources := buildKnowledgeAndSources(docs)
|
||||||
|
|
||||||
|
// 2. 构建提示词
|
||||||
|
msgs, err := buildPromptMessages(ctx, knowledge, content)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 🔥 直接使用全局单例,不重复创建
|
||||||
|
replyMsg, err = streamGenerateAnswer(ctx, globalChatModel, msgs)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildKnowledgeAndSources 拼接参考知识 + 提取文档来源
|
||||||
|
func buildKnowledgeAndSources(docs []*schema.Document) (string, []string) {
|
||||||
|
var knowledge string
|
||||||
|
var sources []string
|
||||||
|
|
||||||
|
for i, doc := range docs {
|
||||||
|
knowledge += fmt.Sprintf("[参考%d] %s\n", i+1, doc.Content)
|
||||||
|
|
||||||
|
// 提取 document_id
|
||||||
|
if docID, ok := doc.MetaData["document_id"].(int64); ok && docID > 0 {
|
||||||
|
sources = append(sources, gconv.String(docID))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return knowledge, sources
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildPromptMessages 构建提示词模板
|
||||||
|
func buildPromptMessages(ctx context.Context, knowledge string, question string) (msgs []*schema.Message, err error) {
|
||||||
|
promptTpl := prompt.FromMessages(
|
||||||
|
schema.FString,
|
||||||
|
&schema.Message{
|
||||||
|
Role: schema.System,
|
||||||
|
// Content: `你是专业的客服助手,语气友好。
|
||||||
|
//如果参考知识中有相关信息,请优先依据参考知识回答。
|
||||||
|
//如果没有相关信息,就正常回答,不要说无法回答。
|
||||||
|
//
|
||||||
|
//参考知识:
|
||||||
|
//{knowledge}`,
|
||||||
|
Content: `你是专业的客服助手,语气友好。
|
||||||
|
请根据参考知识回答用户问题,无法回答则说:抱歉,我暂时无法回答这个问题。
|
||||||
|
|
||||||
|
参考知识:
|
||||||
|
{knowledge}`,
|
||||||
|
},
|
||||||
|
&schema.Message{
|
||||||
|
Role: schema.User,
|
||||||
|
Content: "{question}",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return promptTpl.Format(ctx, map[string]any{
|
||||||
|
"knowledge": knowledge,
|
||||||
|
"question": question,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// streamGenerateAnswer 流式生成
|
||||||
|
func streamGenerateAnswer(ctx context.Context, chatModel *qwen.ChatModel, msgs []*schema.Message) (reply *schema.Message, err error) {
|
||||||
|
|
||||||
|
sr, err := chatModel.Stream(ctx, msgs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("stream failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var chunks []*schema.Message
|
||||||
|
for {
|
||||||
|
chunk, err := sr.Recv()
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("stream recv failed: %w", err)
|
||||||
|
}
|
||||||
|
chunks = append(chunks, chunk)
|
||||||
|
}
|
||||||
|
|
||||||
|
return schema.ConcatMessages(chunks)
|
||||||
|
}
|
||||||
@@ -1,273 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright 2024 Red Future Authors
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package eino
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/cloudwego/eino/callbacks"
|
|
||||||
"github.com/cloudwego/eino/components"
|
|
||||||
"github.com/cloudwego/eino/components/embedding"
|
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
|
||||||
"github.com/gogf/gf/v2/net/gclient"
|
|
||||||
"github.com/gogf/gf/v2/util/gconv"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
// 千问API默认配置
|
|
||||||
defaultBaseURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding"
|
|
||||||
defaultTimeout = 10 * time.Minute
|
|
||||||
defaultRetryTimes = 2
|
|
||||||
)
|
|
||||||
|
|
||||||
type QwenEmbeddingConfig struct {
|
|
||||||
// Timeout specifies the maximum duration to wait for API responses
|
|
||||||
// Optional. Default: 10 minutes
|
|
||||||
Timeout *time.Duration `json:"timeout"`
|
|
||||||
|
|
||||||
// HTTPClient specifies the client to send HTTP requests.
|
|
||||||
// Optional. Default &http.Client{Timeout: Timeout}
|
|
||||||
HTTPClient *http.Client `json:"http_client"`
|
|
||||||
|
|
||||||
// RetryTimes specifies the number of retry attempts for failed API calls
|
|
||||||
// Optional. Default: 2
|
|
||||||
RetryTimes *int `json:"retry_times"`
|
|
||||||
|
|
||||||
// BaseURL specifies the base URL for Qwen DashScope service
|
|
||||||
// Optional. Default: "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding"
|
|
||||||
BaseURL string `json:"base_url"`
|
|
||||||
|
|
||||||
// APIKey specifies the API Key for authentication
|
|
||||||
// Required
|
|
||||||
APIKey string `json:"api_key"`
|
|
||||||
|
|
||||||
// Model specifies the model name for Qwen embedding
|
|
||||||
// Required. Examples: "text-embedding-v2", "text-embedding-v3"
|
|
||||||
Model string `json:"model"`
|
|
||||||
|
|
||||||
// TextType specifies the type of text: "document" or "query"
|
|
||||||
// Optional. Default: "document"
|
|
||||||
TextType string `json:"text_type"`
|
|
||||||
|
|
||||||
// MaxConcurrentRequests specifies the maximum number of concurrent requests allowed
|
|
||||||
// Optional. Default: 5
|
|
||||||
MaxConcurrentRequests *int `json:"max_concurrent_requests"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type QwenEmbedder struct {
|
|
||||||
client *gclient.Client
|
|
||||||
conf *QwenEmbeddingConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
// EmbeddingRequest 千问embedding请求结构
|
|
||||||
type EmbeddingRequest struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
Input struct {
|
|
||||||
Texts []string `json:"texts"`
|
|
||||||
} `json:"input"`
|
|
||||||
Parameters struct {
|
|
||||||
TextType string `json:"text_type,omitempty"`
|
|
||||||
} `json:"parameters,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// EmbeddingResponse 千问embedding响应结构
|
|
||||||
type EmbeddingResponse struct {
|
|
||||||
Output struct {
|
|
||||||
Embeddings []struct {
|
|
||||||
TextIndex int `json:"text_index"`
|
|
||||||
Embedding []float64 `json:"embedding"`
|
|
||||||
} `json:"embeddings"`
|
|
||||||
} `json:"output"`
|
|
||||||
Usage struct {
|
|
||||||
TotalTokens int `json:"total_tokens"`
|
|
||||||
} `json:"usage"`
|
|
||||||
RequestID string `json:"request_id"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type APIError struct {
|
|
||||||
Code string `json:"code"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
RequestID string `json:"request_id"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *APIError) Error() string {
|
|
||||||
return fmt.Sprintf("API Error: %s - %s (RequestID: %s)", e.Code, e.Message, e.RequestID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildQwenClient(config *QwenEmbeddingConfig) *gclient.Client {
|
|
||||||
if len(config.BaseURL) == 0 {
|
|
||||||
config.BaseURL = defaultBaseURL
|
|
||||||
}
|
|
||||||
if config.Timeout == nil {
|
|
||||||
config.Timeout = &defaultTimeout
|
|
||||||
}
|
|
||||||
if config.RetryTimes == nil {
|
|
||||||
defaultRetryTimes := 2
|
|
||||||
config.RetryTimes = &defaultRetryTimes
|
|
||||||
}
|
|
||||||
if len(config.TextType) == 0 {
|
|
||||||
config.TextType = "document"
|
|
||||||
}
|
|
||||||
if config.MaxConcurrentRequests == nil {
|
|
||||||
defaultMaxConcurrentRequests := 5
|
|
||||||
config.MaxConcurrentRequests = &defaultMaxConcurrentRequests
|
|
||||||
}
|
|
||||||
|
|
||||||
client := g.Client()
|
|
||||||
client.SetTimeout(*config.Timeout)
|
|
||||||
|
|
||||||
return client
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewQwenEmbedder(ctx context.Context, config *QwenEmbeddingConfig) (*QwenEmbedder, error) {
|
|
||||||
if len(config.APIKey) == 0 {
|
|
||||||
return nil, fmt.Errorf("[Qwen] APIKey is required")
|
|
||||||
}
|
|
||||||
if len(config.Model) == 0 {
|
|
||||||
return nil, fmt.Errorf("[Qwen] Model is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
client := buildQwenClient(config)
|
|
||||||
|
|
||||||
return &QwenEmbedder{
|
|
||||||
client: client,
|
|
||||||
conf: config,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *QwenEmbedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) (
|
|
||||||
[][]float64, error) {
|
|
||||||
|
|
||||||
if len(texts) == 0 {
|
|
||||||
return nil, fmt.Errorf("[Qwen] texts cannot be empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
options := embedding.GetCommonOptions(&embedding.Options{
|
|
||||||
Model: &e.conf.Model,
|
|
||||||
}, opts...)
|
|
||||||
|
|
||||||
conf := &embedding.Config{
|
|
||||||
Model: dereferenceOrZero(options.Model),
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx = callbacks.EnsureRunInfo(ctx, e.GetType(), components.ComponentOfEmbedding)
|
|
||||||
ctx = callbacks.OnStart(ctx, &embedding.CallbackInput{
|
|
||||||
Texts: texts,
|
|
||||||
Config: conf,
|
|
||||||
})
|
|
||||||
defer func() {
|
|
||||||
if err := recover(); err != nil {
|
|
||||||
callbacks.OnError(ctx, fmt.Errorf("[Qwen] panic: %v", err))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
var usage *embedding.TokenUsage
|
|
||||||
var embeddings [][]float64
|
|
||||||
var err error
|
|
||||||
|
|
||||||
// 调用千问API获取embedding
|
|
||||||
embeddings, usage, err = e.callEmbeddingAPI(ctx, texts)
|
|
||||||
if err != nil {
|
|
||||||
callbacks.OnError(ctx, err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
callbacks.OnEnd(ctx, &embedding.CallbackOutput{
|
|
||||||
Embeddings: embeddings,
|
|
||||||
Config: conf,
|
|
||||||
TokenUsage: usage,
|
|
||||||
})
|
|
||||||
|
|
||||||
return embeddings, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *QwenEmbedder) callEmbeddingAPI(ctx context.Context, texts []string) ([][]float64, *embedding.TokenUsage, error) {
|
|
||||||
// 构建请求
|
|
||||||
var req EmbeddingRequest
|
|
||||||
req.Model = e.conf.Model
|
|
||||||
req.Input.Texts = texts
|
|
||||||
req.Parameters.TextType = e.conf.TextType
|
|
||||||
|
|
||||||
// 调用API
|
|
||||||
client := e.client.Clone()
|
|
||||||
client.SetHeader("Authorization", "Bearer "+e.conf.APIKey)
|
|
||||||
client.SetHeader("Content-Type", "application/json")
|
|
||||||
client.SetTimeout(*e.conf.Timeout)
|
|
||||||
|
|
||||||
resp, err := client.Post(ctx, e.conf.BaseURL, req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("[Qwen] HTTP request error: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer resp.Close()
|
|
||||||
|
|
||||||
// 检查状态码
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
var errResp APIError
|
|
||||||
result := resp.ReadAll()
|
|
||||||
if err = gconv.Struct(result, &errResp); err == nil && errResp.Code != "" {
|
|
||||||
return nil, nil, &errResp
|
|
||||||
}
|
|
||||||
return nil, nil, fmt.Errorf("[Qwen] HTTP status error: %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析响应
|
|
||||||
var apiResp EmbeddingResponse
|
|
||||||
result := resp.ReadAll()
|
|
||||||
if err = gconv.Struct(result, &apiResp); err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("[Qwen] parse response error: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析响应结果
|
|
||||||
embeddings := make([][]float64, len(texts))
|
|
||||||
for _, emb := range apiResp.Output.Embeddings {
|
|
||||||
if emb.TextIndex >= 0 && emb.TextIndex < len(embeddings) {
|
|
||||||
embeddings[emb.TextIndex] = emb.Embedding
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
usage := &embedding.TokenUsage{
|
|
||||||
TotalTokens: apiResp.Usage.TotalTokens,
|
|
||||||
}
|
|
||||||
|
|
||||||
g.Log().Debugf(ctx, "[Qwen] Embedding success: request_id=%s, total_tokens=%d", apiResp.RequestID, usage.TotalTokens)
|
|
||||||
|
|
||||||
return embeddings, usage, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *QwenEmbedder) GetType() string {
|
|
||||||
return getType()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *QwenEmbedder) IsCallbacksEnabled() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func getType() string {
|
|
||||||
return "Qwen"
|
|
||||||
}
|
|
||||||
|
|
||||||
func dereferenceOrZero[T any](v *T) T {
|
|
||||||
if v == nil {
|
|
||||||
var t T
|
|
||||||
return t
|
|
||||||
}
|
|
||||||
return *v
|
|
||||||
}
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
package eino
|
|
||||||
|
|
||||||
// TaskPriority 任务优先级
|
|
||||||
type TaskPriority string
|
|
||||||
|
|
||||||
const (
|
|
||||||
TaskPriorityLow TaskPriority = "low" // 低优先级
|
|
||||||
TaskPriorityMedium TaskPriority = "medium" // 中优先级
|
|
||||||
TaskPriorityHigh TaskPriority = "high" // 高优先级
|
|
||||||
TaskPriorityUrgent TaskPriority = "urgent" // 紧急
|
|
||||||
)
|
|
||||||
@@ -3,6 +3,8 @@ package eino
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"rag/dao"
|
||||||
|
"sort"
|
||||||
|
|
||||||
"github.com/cloudwego/eino/callbacks"
|
"github.com/cloudwego/eino/callbacks"
|
||||||
"github.com/cloudwego/eino/components/embedding"
|
"github.com/cloudwego/eino/components/embedding"
|
||||||
@@ -16,12 +18,14 @@ type PGVectorRetrieverConfig struct {
|
|||||||
Embedder embedding.Embedder
|
Embedder embedding.Embedder
|
||||||
DefaultTopK int
|
DefaultTopK int
|
||||||
DefaultIndex string
|
DefaultIndex string
|
||||||
|
DSLInfo map[string]any
|
||||||
}
|
}
|
||||||
|
|
||||||
type PGVectorRetriever struct {
|
type PGVectorRetriever struct {
|
||||||
embedder embedding.Embedder
|
embedder embedding.Embedder
|
||||||
topK int
|
topK int
|
||||||
index string
|
index string
|
||||||
|
dslInfo map[string]any
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPGVectorRetriever(config *PGVectorRetrieverConfig) (*PGVectorRetriever, error) {
|
func NewPGVectorRetriever(config *PGVectorRetrieverConfig) (*PGVectorRetriever, error) {
|
||||||
@@ -36,43 +40,62 @@ func NewPGVectorRetriever(config *PGVectorRetrieverConfig) (*PGVectorRetriever,
|
|||||||
embedder: config.Embedder,
|
embedder: config.Embedder,
|
||||||
topK: config.DefaultTopK,
|
topK: config.DefaultTopK,
|
||||||
index: config.DefaultIndex,
|
index: config.DefaultIndex,
|
||||||
|
dslInfo: config.DSLInfo,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *PGVectorRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
|
func (r *PGVectorRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
|
||||||
|
|
||||||
// 1. 处理公共 Option(官方标准写法)
|
|
||||||
options := &retriever.Options{
|
options := &retriever.Options{
|
||||||
Index: &r.index,
|
Index: &r.index,
|
||||||
TopK: &r.topK,
|
TopK: &r.topK,
|
||||||
|
DSLInfo: r.dslInfo,
|
||||||
Embedding: r.embedder,
|
Embedding: r.embedder,
|
||||||
}
|
}
|
||||||
options = retriever.GetCommonOptions(options, opts...)
|
options = retriever.GetCommonOptions(options, opts...)
|
||||||
|
|
||||||
// 2. 回调(官方标准)
|
|
||||||
ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{
|
ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{
|
||||||
Query: query,
|
Query: query,
|
||||||
TopK: *options.TopK,
|
TopK: *options.TopK,
|
||||||
})
|
})
|
||||||
|
|
||||||
// 3. 执行检索
|
// ==========================================
|
||||||
docs, err := r.doRetrieve(ctx, query, options)
|
// 🔥 双路检索:向量 + 全文
|
||||||
|
// ==========================================
|
||||||
|
docsVector, err := r.doRetrieveVector(ctx, query, options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
callbacks.OnError(ctx, err)
|
callbacks.OnError(ctx, err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. 完成回调
|
docsFulltext, err := r.doRetrieveMeilisearch(ctx, query, options)
|
||||||
callbacks.OnEnd(ctx, &retriever.CallbackOutput{
|
if err != nil {
|
||||||
Docs: docs,
|
callbacks.OnError(ctx, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 合并 + 去重
|
||||||
|
docs := mergeAndDeduplicate(docsVector, docsFulltext)
|
||||||
|
|
||||||
|
// 排序(distance 越小越靠前)
|
||||||
|
sort.Slice(docs, func(i, j int) bool {
|
||||||
|
d1 := gconv.Float64(docs[i].MetaData["distance"])
|
||||||
|
d2 := gconv.Float64(docs[j].MetaData["distance"])
|
||||||
|
return d1 < d2
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// 最多保留 topK
|
||||||
|
if len(docs) > *options.TopK {
|
||||||
|
docs = docs[:*options.TopK]
|
||||||
|
}
|
||||||
|
|
||||||
|
callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: docs})
|
||||||
return docs, nil
|
return docs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *PGVectorRetriever) doRetrieve(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
|
// ==========================================
|
||||||
|
// 1. 向量检索(PG)
|
||||||
// 1. 生成向量
|
// ==========================================
|
||||||
|
func (r *PGVectorRetriever) doRetrieveVector(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
|
||||||
vectors, err := opts.Embedding.EmbedStrings(ctx, []string{query})
|
vectors, err := opts.Embedding.EmbedStrings(ctx, []string{query})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -81,37 +104,76 @@ func (r *PGVectorRetriever) doRetrieve(ctx context.Context, query string, opts *
|
|||||||
return nil, errors.New("empty query vector")
|
return nil, errors.New("empty query vector")
|
||||||
}
|
}
|
||||||
|
|
||||||
queryVec := pgvector.NewVector(vectors[0])
|
queryVec := pgvector.NewVector(gconv.Float32s(vectors[0]))
|
||||||
topK := *opts.TopK
|
topK := *opts.TopK
|
||||||
|
datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"])
|
||||||
|
|
||||||
// 2. PG 向量相似度检索 SQL
|
rows, err := dao.DocumentChunk.GetAllByVector(ctx, datasetIds, queryVec, topK)
|
||||||
sql := `
|
|
||||||
SELECT id, content, dataset_id, document_id,
|
|
||||||
vector <-> ? AS distance
|
|
||||||
FROM document_chunk
|
|
||||||
ORDER BY distance ASC
|
|
||||||
LIMIT ?
|
|
||||||
`
|
|
||||||
|
|
||||||
// 3. 查询
|
|
||||||
rows, err := dao.DocumentChunk.GetDB().GetAll(ctx, sql, queryVec, topK)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. 转为 Eino Document
|
|
||||||
docs := make([]*schema.Document, 0, len(rows))
|
docs := make([]*schema.Document, 0, len(rows))
|
||||||
for _, row := range rows {
|
for _, row := range rows {
|
||||||
docs = append(docs, &schema.Document{
|
docs = append(docs, &schema.Document{
|
||||||
ID: gconv.String(row["id"]),
|
ID: gconv.String(row["id"]),
|
||||||
Content: gconv.String(row["content"]),
|
Content: gconv.String(row["content"]),
|
||||||
Metadata: map[string]any{
|
MetaData: map[string]any{
|
||||||
"dataset_id": row["dataset_id"],
|
"dataset_id": gconv.Int64(row["dataset_id"]),
|
||||||
"document_id": row["document_id"],
|
"document_id": gconv.Int64(row["document_id"]),
|
||||||
"distance": row["distance"],
|
"distance": gconv.Float64(row["distance"]),
|
||||||
|
"retrieve_by": "vector",
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return docs, nil
|
return docs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ==========================================
|
||||||
|
// 2. 全文检索(Meilisearch)🔥 新增
|
||||||
|
// ==========================================
|
||||||
|
func (r *PGVectorRetriever) doRetrieveMeilisearch(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
|
||||||
|
topK := *opts.TopK
|
||||||
|
datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"])
|
||||||
|
|
||||||
|
// 调用你已有的 Meilisearch DAO
|
||||||
|
rows, err := dao.DocumentChunk.SearchByKeywords(ctx, query, datasetIds, topK)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
docs := make([]*schema.Document, 0, len(rows))
|
||||||
|
for _, row := range rows {
|
||||||
|
docs = append(docs, &schema.Document{
|
||||||
|
ID: gconv.String(row["id"]),
|
||||||
|
Content: gconv.String(row["content"]),
|
||||||
|
MetaData: map[string]any{
|
||||||
|
"dataset_id": gconv.Int64(row["dataset_id"]),
|
||||||
|
"document_id": gconv.Int64(row["document_id"]),
|
||||||
|
"distance": 0.1, // 全文结果给高分
|
||||||
|
"retrieve_by": "fulltext",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return docs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==========================================
|
||||||
|
// 合并去重
|
||||||
|
// ==========================================
|
||||||
|
func mergeAndDeduplicate(vecDocs, fullDocs []*schema.Document) []*schema.Document {
|
||||||
|
idMap := make(map[string]*schema.Document)
|
||||||
|
for _, d := range vecDocs {
|
||||||
|
idMap[d.ID] = d
|
||||||
|
}
|
||||||
|
for _, d := range fullDocs {
|
||||||
|
if _, exists := idMap[d.ID]; !exists {
|
||||||
|
idMap[d.ID] = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
merged := make([]*schema.Document, 0, len(idMap))
|
||||||
|
for _, d := range idMap {
|
||||||
|
merged = append(merged, d)
|
||||||
|
}
|
||||||
|
return merged
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,12 +0,0 @@
|
|||||||
package eino
|
|
||||||
|
|
||||||
// TaskStatus 任务状态
|
|
||||||
type TaskStatus string
|
|
||||||
|
|
||||||
const (
|
|
||||||
TaskStatusPending TaskStatus = "pending" // 待处理
|
|
||||||
TaskStatusRunning TaskStatus = "running" // 运行中
|
|
||||||
TaskStatusCompleted TaskStatus = "completed" // 已完成
|
|
||||||
TaskStatusFailed TaskStatus = "failed" // 失败
|
|
||||||
TaskStatusCancelled TaskStatus = "cancelled" // 已取消
|
|
||||||
)
|
|
||||||
@@ -1,14 +0,0 @@
|
|||||||
package eino
|
|
||||||
|
|
||||||
// TaskType 任务类型
|
|
||||||
type TaskType string
|
|
||||||
|
|
||||||
const (
|
|
||||||
TaskTypeDocumentIngestion TaskType = "document_ingestion" // 文档摄入任务
|
|
||||||
TaskTypeVectorIngestion TaskType = "vector_ingestion" // 向量摄入任务
|
|
||||||
TaskTypeIndexCreation TaskType = "index_creation" // 索引创建任务
|
|
||||||
TaskTypeQAProcessing TaskType = "qa_processing" // 问答处理任务
|
|
||||||
TaskTypeKnowledgeConstruction TaskType = "knowledge_construction" // 知识库构建任务
|
|
||||||
TaskTypeGraphBuilding TaskType = "graph_building" // 图谱构建任务
|
|
||||||
TaskTypeKnowledgeSync TaskType = "knowledge_sync" // 知识同步任务
|
|
||||||
)
|
|
||||||
@@ -1,114 +0,0 @@
|
|||||||
package gse
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"sort"
|
|
||||||
|
|
||||||
"github.com/go-ego/gse"
|
|
||||||
"github.com/go-ego/gse/hmm/extracker"
|
|
||||||
"github.com/go-ego/gse/hmm/segment"
|
|
||||||
"github.com/gogf/gf/v2/os/glog"
|
|
||||||
)
|
|
||||||
|
|
||||||
var GseTool *gseTool
|
|
||||||
|
|
||||||
// 初始化函数:程序启动时执行一次
|
|
||||||
func init() {
|
|
||||||
var err error
|
|
||||||
GseTool, err = newGseTool()
|
|
||||||
if err != nil {
|
|
||||||
glog.Error(context.Background(), err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// gseTool 关键词提取工具(gse v1.0.2 标准)
|
|
||||||
type gseTool struct {
|
|
||||||
seg gse.Segmenter
|
|
||||||
tfidf *extracker.TagExtracter
|
|
||||||
tr *extracker.TextRanker
|
|
||||||
}
|
|
||||||
|
|
||||||
// newGseTool 初始化工具(内置词典 + 停用词)
|
|
||||||
func newGseTool() (tool *gseTool, err error) {
|
|
||||||
// 1. 初始化分词器
|
|
||||||
var seg gse.Segmenter
|
|
||||||
// 内置词典(无外部文件)
|
|
||||||
err = seg.LoadDictEmbed()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// 内置停用词(v1.0.2 标准)
|
|
||||||
err = seg.LoadStopEmbed()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. 初始化 TF-IDF 提取器
|
|
||||||
tfidf := &extracker.TagExtracter{}
|
|
||||||
tfidf.WithGse(seg)
|
|
||||||
err = tfidf.LoadIdf()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. 初始化 TextRank 提取器
|
|
||||||
tr := &extracker.TextRanker{}
|
|
||||||
tr.WithGse(seg)
|
|
||||||
|
|
||||||
tool = &gseTool{
|
|
||||||
seg: seg,
|
|
||||||
tfidf: tfidf,
|
|
||||||
tr: tr,
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cut 分词(关键词提取唯一正确模式:精确模式 + HMM)
|
|
||||||
func (k *gseTool) Cut(text string) []string {
|
|
||||||
return k.seg.Cut(text, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Keyword 最终输出:关键词 + 权重
|
|
||||||
type Keyword struct {
|
|
||||||
Word string `json:"word"`
|
|
||||||
Score float64 `json:"score"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (k *gseTool) Extract(text string, topN int) []Keyword {
|
|
||||||
// 1. 提取 TF-IDF
|
|
||||||
tfTags := k.extractTFIDF(text, topN)
|
|
||||||
|
|
||||||
// 2. 提取 TextRank
|
|
||||||
trTags := k.extractTextRank(text, topN)
|
|
||||||
|
|
||||||
// 3. 合并成最终关键词(业务最常用)
|
|
||||||
scoreMap := make(map[string]float64)
|
|
||||||
for _, tag := range tfTags {
|
|
||||||
scoreMap[tag.Text] = tag.Weight
|
|
||||||
}
|
|
||||||
for _, tag := range trTags {
|
|
||||||
scoreMap[tag.Text] = tag.Weight
|
|
||||||
}
|
|
||||||
|
|
||||||
// 转成切片并排序(高分在前)
|
|
||||||
res := make([]Keyword, 0, len(scoreMap))
|
|
||||||
for word, score := range scoreMap {
|
|
||||||
res = append(res, Keyword{Word: word, Score: score})
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Slice(res, func(i, j int) bool {
|
|
||||||
return res[i].Score > res[j].Score
|
|
||||||
})
|
|
||||||
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExtractTFIDF TF-IDF 关键词(带权重)90% 业务:文章标签、搜索、关键词
|
|
||||||
func (k *gseTool) extractTFIDF(text string, topN int) segment.Segments {
|
|
||||||
return k.tfidf.ExtractTags(text, topN)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExtractTextRank TextRank 关键词(带权重)长文本、摘要、语义理解
|
|
||||||
func (k *gseTool) extractTextRank(text string, topN int) segment.Segments {
|
|
||||||
return k.tr.TextRank(text, topN)
|
|
||||||
}
|
|
||||||
69
common/task/base_task.go
Normal file
69
common/task/base_task.go
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
package task
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gitea.com/red-future/common/beans"
|
||||||
|
)
|
||||||
|
|
||||||
|
type baseTaskCol struct {
|
||||||
|
beans.SQLBaseCol
|
||||||
|
TaskType string
|
||||||
|
Status string
|
||||||
|
Priority string
|
||||||
|
ParentTaskID string
|
||||||
|
TotalItems string
|
||||||
|
ProcessedItems string
|
||||||
|
Progress string
|
||||||
|
StartTime string
|
||||||
|
EndTime string
|
||||||
|
Duration string
|
||||||
|
SuccessCount string
|
||||||
|
FailCount string
|
||||||
|
Executor string
|
||||||
|
DocumentID string
|
||||||
|
Remark string
|
||||||
|
}
|
||||||
|
|
||||||
|
var BaseTaskCol = baseTaskCol{
|
||||||
|
SQLBaseCol: beans.DefSQLBaseCol,
|
||||||
|
TaskType: "task_type",
|
||||||
|
Status: "status",
|
||||||
|
Priority: "task_priority",
|
||||||
|
ParentTaskID: "parent_task_id",
|
||||||
|
TotalItems: "total_items",
|
||||||
|
ProcessedItems: "processed_items",
|
||||||
|
Progress: "progress",
|
||||||
|
StartTime: "start_time",
|
||||||
|
EndTime: "end_time",
|
||||||
|
Duration: "duration",
|
||||||
|
SuccessCount: "success_count",
|
||||||
|
FailCount: "fail_count",
|
||||||
|
Executor: "executor",
|
||||||
|
DocumentID: "document_id",
|
||||||
|
Remark: "remark",
|
||||||
|
}
|
||||||
|
|
||||||
|
// SQLBaseTask 任务基类 - SQL版本
|
||||||
|
type SQLBaseTask struct {
|
||||||
|
beans.SQLBaseDO `orm:",inline"`
|
||||||
|
// 任务核心信息
|
||||||
|
TaskType TaskType `orm:"task_type" json:"taskType" dc:"任务类型"`
|
||||||
|
Status TaskStatus `orm:"status" json:"status" dc:"任务状态"`
|
||||||
|
Priority TaskPriority `orm:"task_priority" json:"priority,omitempty" dc:"任务优先级"`
|
||||||
|
ParentTaskID int64 `orm:"parent_task_id" json:"parentTaskId,omitempty" dc:"父任务ID"`
|
||||||
|
// 任务进度
|
||||||
|
TotalItems int64 `orm:"total_items" json:"totalItems" dc:"总数"`
|
||||||
|
ProcessedItems int64 `orm:"processed_items" json:"processedItems" dc:"已处理数"`
|
||||||
|
Progress float64 `orm:"progress" json:"progress" dc:"进度"` // 0~100 百分比
|
||||||
|
// 任务结果
|
||||||
|
StartTime *time.Time `orm:"start_time" json:"startTime" dc:"开始时间"`
|
||||||
|
EndTime *time.Time `orm:"end_time" json:"endTime,omitempty" dc:"结束时间"`
|
||||||
|
Duration int64 `orm:"duration" json:"duration,omitempty" dc:"耗时(毫秒)"`
|
||||||
|
SuccessCount int64 `orm:"success_count" json:"successCount" dc:"成功数"`
|
||||||
|
FailCount int64 `orm:"fail_count" json:"failCount" dc:"失败数"`
|
||||||
|
// 其他
|
||||||
|
Executor string `orm:"executor" json:"executor,omitempty" dc:"执行器标识"`
|
||||||
|
DocumentID int64 `orm:"document_id" json:"documentId,omitempty" dc:"文档ID"`
|
||||||
|
Remark string `orm:"remark" json:"remark,omitempty" dc:"备注/错误信息"`
|
||||||
|
}
|
||||||
30
common/task/consts.go
Normal file
30
common/task/consts.go
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
package task
|
||||||
|
|
||||||
|
// TaskType 任务类型枚举:文档解析的三个子任务
|
||||||
|
type TaskType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
TaskTypeExtractKeywords TaskType = "EXTRACT_KEYWORDS" // 提取关键词
|
||||||
|
TaskTypeGenerateVector TaskType = "GENERATE_VECTOR" // 生成向量
|
||||||
|
TaskTypeFullTextSearch TaskType = "FULL_TEXT_SEARCH" // 全文检索
|
||||||
|
TaskTypeDocParse TaskType = "DOC_PARSE" // 顶层文档解析总任务
|
||||||
|
)
|
||||||
|
|
||||||
|
// TaskStatus 任务状态枚举
|
||||||
|
type TaskStatus string
|
||||||
|
|
||||||
|
const (
|
||||||
|
TaskStatusPending TaskStatus = "PENDING" // 待执行
|
||||||
|
TaskStatusRunning TaskStatus = "RUNNING" // 执行中
|
||||||
|
TaskStatusCompleted TaskStatus = "COMPLETED" // 已完成
|
||||||
|
TaskStatusFailed TaskStatus = "FAILED" // 执行失败
|
||||||
|
)
|
||||||
|
|
||||||
|
// TaskPriority 任务优先级
|
||||||
|
type TaskPriority int
|
||||||
|
|
||||||
|
const (
|
||||||
|
TaskPriorityLow TaskPriority = 1 // 低
|
||||||
|
TaskPriorityMedium TaskPriority = 2 // 中
|
||||||
|
TaskPriorityHigh TaskPriority = 3 // 高
|
||||||
|
)
|
||||||
24
config.yml
24
config.yml
@@ -48,10 +48,10 @@ database:
|
|||||||
timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性,为true时CreatedAt/UpdatedAt/DeletedAt都将失效
|
timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性,为true时CreatedAt/UpdatedAt/DeletedAt都将失效
|
||||||
rag_knowledge:
|
rag_knowledge:
|
||||||
- type: "pgsql"
|
- type: "pgsql"
|
||||||
host: "localhost"
|
host: "116.204.74.41"
|
||||||
port: "5432"
|
port: "15432"
|
||||||
user: "postgres"
|
user: "postgres"
|
||||||
pass: "123456"
|
pass: "Bjang09@686^*^"
|
||||||
name: "tenant-1"
|
name: "tenant-1"
|
||||||
prefix: "rag_knowledge_" # (可选)表名前缀
|
prefix: "rag_knowledge_" # (可选)表名前缀
|
||||||
role: "master"
|
role: "master"
|
||||||
@@ -69,10 +69,10 @@ database:
|
|||||||
timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性,为true时CreatedAt/UpdatedAt/DeletedAt都将失效
|
timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性,为true时CreatedAt/UpdatedAt/DeletedAt都将失效
|
||||||
rag_vector:
|
rag_vector:
|
||||||
- type: "pgsql"
|
- type: "pgsql"
|
||||||
host: "localhost"
|
host: "116.204.74.41"
|
||||||
port: "5432"
|
port: "15432"
|
||||||
user: "postgres"
|
user: "postgres"
|
||||||
pass: "123456"
|
pass: "Bjang09@686^*^"
|
||||||
name: "tenant-1"
|
name: "tenant-1"
|
||||||
prefix: "rag_vector_" # (可选)表名前缀
|
prefix: "rag_vector_" # (可选)表名前缀
|
||||||
role: "master"
|
role: "master"
|
||||||
@@ -91,14 +91,14 @@ database:
|
|||||||
|
|
||||||
redis:
|
redis:
|
||||||
default:
|
default:
|
||||||
address: "localhost:6379"
|
address: "116.204.74.41:6379"
|
||||||
db: 0
|
db: 0
|
||||||
|
|
||||||
consul:
|
consul:
|
||||||
address: localhost:8500
|
address: 116.204.74.41:8500
|
||||||
|
|
||||||
jaeger:
|
jaeger:
|
||||||
addr: localhost:4318
|
addr: 116.204.74.41:4318
|
||||||
|
|
||||||
# eino框架配置
|
# eino框架配置
|
||||||
eino:
|
eino:
|
||||||
@@ -115,6 +115,10 @@ eino:
|
|||||||
# apiType: "multi_modal_api"
|
# apiType: "multi_modal_api"
|
||||||
apiKey: "sk-4a8b82770bf74bc490eb3e4c5a8e2be9"
|
apiKey: "sk-4a8b82770bf74bc490eb3e4c5a8e2be9"
|
||||||
model: "text-embedding-v3"
|
model: "text-embedding-v3"
|
||||||
|
chatmodel:
|
||||||
|
provider: "dashscope"
|
||||||
|
apiKey: "sk-4a8b82770bf74bc490eb3e4c5a8e2be9"
|
||||||
|
model: "qwen-turbo"
|
||||||
|
|
||||||
# 文件上传服务地址,与oss模块minio中的endpoint一致
|
# 文件上传服务地址,与oss模块minio中的endpoint一致
|
||||||
filePrefix: "http://116.204.74.41:9000"
|
filePrefix: "http://116.204.74.41:9000"
|
||||||
@@ -122,7 +126,7 @@ filePrefix: "http://116.204.74.41:9000"
|
|||||||
gmq:
|
gmq:
|
||||||
redis:
|
redis:
|
||||||
primary:
|
primary:
|
||||||
addr: "localhost"
|
addr: "116.204.74.41"
|
||||||
port: "6379"
|
port: "6379"
|
||||||
db: 0
|
db: 0
|
||||||
username: ""
|
username: ""
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
package public
|
package public
|
||||||
|
|
||||||
|
// 数据库名称
|
||||||
const (
|
const (
|
||||||
DbNameKnowledge = "rag_knowledge"
|
DbNameKnowledge = "rag_knowledge"
|
||||||
DbNameVector = "rag_vector"
|
DbNameVector = "rag_vector"
|
||||||
@@ -10,6 +11,7 @@ const (
|
|||||||
TableNameDocument = "document"
|
TableNameDocument = "document"
|
||||||
TableNameDataset = "dataset"
|
TableNameDataset = "dataset"
|
||||||
TableNameKeyword = "keyword"
|
TableNameKeyword = "keyword"
|
||||||
|
TableNameTask = "task"
|
||||||
TableNameDatasetIndex = "dataset_index"
|
TableNameDatasetIndex = "dataset_index"
|
||||||
TableNameDocumentChunk = "document_chunk"
|
TableNameDocumentChunk = "document_chunk"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ func (c *document) List(ctx context.Context, req *dto.ListDocumentReq) (res *dto
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process 处理文件(向量化)
|
// Process 处理文件(向量化)
|
||||||
func (c *document) Process(ctx context.Context, req *dto.ProcessDocumentReq) (res *dto.ProcessDocumentRes, err error) {
|
func (c *document) Process(ctx context.Context, req *dto.ProcessDocumentReq) (res *beans.ResponseEmpty, err error) {
|
||||||
res, err = service.Document.Process(ctx, req)
|
err = service.Document.Process(ctx, req)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
17
controller/rag_query.go
Normal file
17
controller/rag_query.go
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"rag/model/dto"
|
||||||
|
"rag/service"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ragQuery struct{}
|
||||||
|
|
||||||
|
var RAGQuery = new(ragQuery)
|
||||||
|
|
||||||
|
// Query 执行RAG查询
|
||||||
|
func (c *ragQuery) Query(ctx context.Context, req *dto.RAGQueryReq) (res *dto.RAGQueryRes, err error) {
|
||||||
|
res, err = service.RAGQuery.Query(ctx, req)
|
||||||
|
return
|
||||||
|
}
|
||||||
@@ -49,6 +49,7 @@ func (d *datasetIndexDao) InsertIndex(ctx context.Context, indexName string) (er
|
|||||||
CREATE INDEX IF NOT EXISTS %s
|
CREATE INDEX IF NOT EXISTS %s
|
||||||
ON %s
|
ON %s
|
||||||
USING ivfflat (vector vector_cosine_ops)
|
USING ivfflat (vector vector_cosine_ops)
|
||||||
|
WITH (lists = 100)
|
||||||
WHERE vector IS NOT NULL;
|
WHERE vector IS NOT NULL;
|
||||||
`, indexName, gfdb.TablePrefix+public.TableNameDocumentChunk)
|
`, indexName, gfdb.TablePrefix+public.TableNameDocumentChunk)
|
||||||
_, err = db.Exec(ctx, sqlStr)
|
_, err = db.Exec(ctx, sqlStr)
|
||||||
|
|||||||
@@ -2,12 +2,17 @@ package dao
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"rag/consts/public"
|
"rag/consts/public"
|
||||||
"rag/model/dto"
|
"rag/model/dto"
|
||||||
"rag/model/entity"
|
"rag/model/entity"
|
||||||
|
|
||||||
"gitea.com/red-future/common/db/gfdb"
|
"gitea.com/red-future/common/db/gfdb"
|
||||||
|
"gitea.com/red-future/common/full-text-search/meilisearch"
|
||||||
|
"github.com/gogf/gf/v2/database/gdb"
|
||||||
|
"github.com/gogf/gf/v2/text/gstr"
|
||||||
"github.com/gogf/gf/v2/util/gconv"
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
|
"github.com/pgvector/pgvector-go"
|
||||||
)
|
)
|
||||||
|
|
||||||
var DocumentChunk = new(documentChunkDao)
|
var DocumentChunk = new(documentChunkDao)
|
||||||
@@ -55,3 +60,56 @@ func (d *documentChunkDao) List(ctx context.Context, req *dto.ListDocumentChunkR
|
|||||||
err = r.Structs(&res)
|
err = r.Structs(&res)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *documentChunkDao) GetAllByVector(ctx context.Context, datasetId []int64, queryVec pgvector.Vector, topK int) (list gdb.List, err error) {
|
||||||
|
sql := `
|
||||||
|
SELECT id, content, dataset_id, document_id,
|
||||||
|
vector <-> ? AS distance
|
||||||
|
FROM rag_vector_document_chunk
|
||||||
|
WHERE dataset_id IN (?)
|
||||||
|
AND vector IS NOT NULL
|
||||||
|
ORDER BY distance ASC
|
||||||
|
LIMIT ?
|
||||||
|
`
|
||||||
|
// 顺序:vector, dataset_id, topK
|
||||||
|
result, err := gfdb.DB(ctx, public.DbNameVector).GetAll(ctx, sql, queryVec, datasetId, topK)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.List(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SearchByKeywords 通过关键词全文检索文档块
|
||||||
|
func (d *documentChunkDao) SearchByKeywords(ctx context.Context, query string, datasetIds []int64, topK int) (list gdb.List, err error) {
|
||||||
|
// 构建 meilisearch 查询参数
|
||||||
|
searchParams := &meilisearch.SearchParams{
|
||||||
|
Query: query,
|
||||||
|
Limit: int64(topK),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建 datasetIds 过滤条件
|
||||||
|
if len(datasetIds) > 0 {
|
||||||
|
datasetIdStrs := gconv.Strings(datasetIds)
|
||||||
|
quotedIds := make([]string, len(datasetIdStrs))
|
||||||
|
for i, id := range datasetIdStrs {
|
||||||
|
quotedIds[i] = fmt.Sprintf("%s", id)
|
||||||
|
}
|
||||||
|
searchParams.Filter = fmt.Sprintf("dataset_id IN [%s]", gstr.Implode(", ", quotedIds))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 执行搜索
|
||||||
|
var hits []map[string]interface{}
|
||||||
|
_, err = meilisearch.DB().Search(ctx, searchParams, public.IndexNameDocumentChunk, &hits)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换查询结果为 gdb.List
|
||||||
|
resultList := make(gdb.List, 0, len(hits))
|
||||||
|
for _, hit := range hits {
|
||||||
|
resultList = append(resultList, hit)
|
||||||
|
}
|
||||||
|
|
||||||
|
return resultList, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -82,6 +82,9 @@ func (d *keywordDao) List(ctx context.Context, req *dto.ListKeywordReq, fields .
|
|||||||
if !g.IsEmpty(req.Keyword) {
|
if !g.IsEmpty(req.Keyword) {
|
||||||
model.WhereLike(entity.KeywordCol.Word, "%"+req.Keyword+"%")
|
model.WhereLike(entity.KeywordCol.Word, "%"+req.Keyword+"%")
|
||||||
}
|
}
|
||||||
|
model.WhereIn(entity.KeywordCol.Word, req.Words)
|
||||||
|
model.Where(entity.KeywordCol.DatasetId, req.DatasetId)
|
||||||
|
model.Where(entity.KeywordCol.DocumentId, req.DocumentId)
|
||||||
model.OrderDesc(entity.KeywordCol.Weight)
|
model.OrderDesc(entity.KeywordCol.Weight)
|
||||||
model.OrderDesc(entity.KeywordCol.CreatedAt)
|
model.OrderDesc(entity.KeywordCol.CreatedAt)
|
||||||
if req.Page != nil {
|
if req.Page != nil {
|
||||||
|
|||||||
58
dao/task.go
Normal file
58
dao/task.go
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
package dao
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"rag/consts/public"
|
||||||
|
"rag/model/dto"
|
||||||
|
"rag/model/entity"
|
||||||
|
|
||||||
|
"gitea.com/red-future/common/db/gfdb"
|
||||||
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
var Task = new(taskDao)
|
||||||
|
|
||||||
|
type taskDao struct{}
|
||||||
|
|
||||||
|
// Insert 创建任务
|
||||||
|
func (d *taskDao) Insert(ctx context.Context, req *dto.CreateTaskReq) (id int64, err error) {
|
||||||
|
var res *entity.Task
|
||||||
|
if err = gconv.Struct(req, &res); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r, err := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameTask).Data(&res).Insert()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return r.LastInsertId()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update 更新任务
|
||||||
|
func (d *taskDao) Update(ctx context.Context, req *dto.UpdateTaskReq) (rows int64, err error) {
|
||||||
|
model := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameTask)
|
||||||
|
r, err := model.Data(&req).Where(entity.TaskCol.Id, req.Id).Where(entity.TaskCol.TaskId, req.TaskId).OmitEmpty().Update()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return r.RowsAffected()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *taskDao) Get(ctx context.Context, req *dto.GetTaskReq) (res []*entity.Task, total int, err error) {
|
||||||
|
r, total, err := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameTask).OmitEmpty().
|
||||||
|
Where(entity.TaskCol.Id, req.Id).
|
||||||
|
Where(entity.TaskCol.TaskId, req.TaskId).
|
||||||
|
Where(entity.TaskCol.TaskType, req.TaskType).AllAndCount(false)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = r.Structs(&res)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *taskDao) DeleteByTaskId(ctx context.Context, req *dto.DeleteTaskByTaskIdReq) (rows int64, err error) {
|
||||||
|
r, err := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameTask).Where(entity.TaskCol.TaskId, req.TaskId).Delete()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return r.RowsAffected()
|
||||||
|
}
|
||||||
9
go.mod
9
go.mod
@@ -16,9 +16,9 @@ require (
|
|||||||
github.com/cloudwego/eino-ext/components/embedding/dashscope v0.0.0-20260323112355-f061db7e8419
|
github.com/cloudwego/eino-ext/components/embedding/dashscope v0.0.0-20260323112355-f061db7e8419
|
||||||
github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260323112355-f061db7e8419
|
github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260323112355-f061db7e8419
|
||||||
github.com/cloudwego/eino-ext/components/indexer/es8 v0.0.0-20260331071634-4f359694d2d9
|
github.com/cloudwego/eino-ext/components/indexer/es8 v0.0.0-20260331071634-4f359694d2d9
|
||||||
|
github.com/cloudwego/eino-ext/components/model/qwen v0.1.7
|
||||||
github.com/cloudwego/eino-ext/components/retriever/es8 v0.0.0-20260331071634-4f359694d2d9
|
github.com/cloudwego/eino-ext/components/retriever/es8 v0.0.0-20260331071634-4f359694d2d9
|
||||||
github.com/elastic/go-elasticsearch/v8 v8.16.0
|
github.com/elastic/go-elasticsearch/v8 v8.16.0
|
||||||
github.com/go-ego/gse v1.0.2
|
|
||||||
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.0
|
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.0
|
||||||
github.com/gogf/gf/v2 v2.10.0
|
github.com/gogf/gf/v2 v2.10.0
|
||||||
github.com/golang/glog v1.2.5
|
github.com/golang/glog v1.2.5
|
||||||
@@ -50,7 +50,7 @@ require (
|
|||||||
github.com/clipperhouse/uax29/v2 v2.7.0 // indirect
|
github.com/clipperhouse/uax29/v2 v2.7.0 // indirect
|
||||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||||
github.com/cloudwego/eino-ext/components/document/parser/html v0.0.0-20241224063832-9fbcc0e56c28 // indirect
|
github.com/cloudwego/eino-ext/components/document/parser/html v0.0.0-20241224063832-9fbcc0e56c28 // indirect
|
||||||
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.14 // indirect
|
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.15 // indirect
|
||||||
github.com/dgraph-io/badger/v4 v4.2.0 // indirect
|
github.com/dgraph-io/badger/v4 v4.2.0 // indirect
|
||||||
github.com/dgraph-io/ristretto v0.1.1 // indirect
|
github.com/dgraph-io/ristretto v0.1.1 // indirect
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||||
@@ -64,6 +64,7 @@ require (
|
|||||||
github.com/fatih/color v1.19.0 // indirect
|
github.com/fatih/color v1.19.0 // indirect
|
||||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||||
github.com/gabriel-vasile/mimetype v1.4.13 // indirect
|
github.com/gabriel-vasile/mimetype v1.4.13 // indirect
|
||||||
|
github.com/go-ego/gse v1.0.2 // indirect
|
||||||
github.com/go-logr/logr v1.4.3 // indirect
|
github.com/go-logr/logr v1.4.3 // indirect
|
||||||
github.com/go-logr/stdr v1.2.2 // indirect
|
github.com/go-logr/stdr v1.2.2 // indirect
|
||||||
github.com/go-playground/locales v0.14.1 // indirect
|
github.com/go-playground/locales v0.14.1 // indirect
|
||||||
@@ -105,7 +106,7 @@ require (
|
|||||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
github.com/mattn/go-runewidth v0.0.21 // indirect
|
github.com/mattn/go-runewidth v0.0.21 // indirect
|
||||||
github.com/meguminnnnnnnnn/go-openai v0.1.1 // indirect
|
github.com/meguminnnnnnnnn/go-openai v0.1.2 // indirect
|
||||||
github.com/meilisearch/meilisearch-go v0.36.1 // indirect
|
github.com/meilisearch/meilisearch-go v0.36.1 // indirect
|
||||||
github.com/microcosm-cc/bluemonday v1.0.27 // indirect
|
github.com/microcosm-cc/bluemonday v1.0.27 // indirect
|
||||||
github.com/mitchellh/go-homedir v1.1.0 // indirect
|
github.com/mitchellh/go-homedir v1.1.0 // indirect
|
||||||
@@ -134,7 +135,7 @@ require (
|
|||||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
github.com/vcaesar/cedar v0.30.0 // indirect
|
github.com/vcaesar/cedar v0.30.0 // indirect
|
||||||
github.com/volcengine/volc-sdk-golang v1.0.199 // indirect
|
github.com/volcengine/volc-sdk-golang v1.0.199 // indirect
|
||||||
github.com/volcengine/volcengine-go-sdk v1.0.181 // indirect
|
github.com/volcengine/volcengine-go-sdk v1.2.9 // indirect
|
||||||
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
||||||
github.com/xuri/efp v0.0.0-20240408161823-9ad904a10d6d // indirect
|
github.com/xuri/efp v0.0.0-20240408161823-9ad904a10d6d // indirect
|
||||||
github.com/xuri/excelize/v2 v2.9.0 // indirect
|
github.com/xuri/excelize/v2 v2.9.0 // indirect
|
||||||
|
|||||||
17
go.sum
17
go.sum
@@ -33,8 +33,6 @@ cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9
|
|||||||
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
|
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
|
||||||
entgo.io/ent v0.14.3 h1:wokAV/kIlH9TeklJWGGS7AYJdVckr0DloWjIcO9iIIQ=
|
entgo.io/ent v0.14.3 h1:wokAV/kIlH9TeklJWGGS7AYJdVckr0DloWjIcO9iIIQ=
|
||||||
entgo.io/ent v0.14.3/go.mod h1:aDPE/OziPEu8+OWbzy4UlvWmD2/kbRuWfK2A40hcxJM=
|
entgo.io/ent v0.14.3/go.mod h1:aDPE/OziPEu8+OWbzy4UlvWmD2/kbRuWfK2A40hcxJM=
|
||||||
gitea.com/red-future/common v0.0.11 h1:AV7W3G0uZ8aPpHHSHd4ZHmLWe5+2STPKe/AYPoPCWVc=
|
|
||||||
gitea.com/red-future/common v0.0.11/go.mod h1:B8syUI4XbLCDQSeRHURYxEwnWw8mEFgmqCxjC+lM+NU=
|
|
||||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||||
github.com/BurntSushi/toml v1.1.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
|
github.com/BurntSushi/toml v1.1.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
|
||||||
github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk=
|
github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk=
|
||||||
@@ -158,10 +156,12 @@ github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260323112355-
|
|||||||
github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260323112355-f061db7e8419/go.mod h1:SajSFFRIXJXIbxadAAlSUIS5KTY8R/jzJg9RNSOXCCI=
|
github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260323112355-f061db7e8419/go.mod h1:SajSFFRIXJXIbxadAAlSUIS5KTY8R/jzJg9RNSOXCCI=
|
||||||
github.com/cloudwego/eino-ext/components/indexer/es8 v0.0.0-20260331071634-4f359694d2d9 h1:vZ3dL8xwo2sy73aBVKs4AJiO5OCHRxMOJUwIYkp0CWs=
|
github.com/cloudwego/eino-ext/components/indexer/es8 v0.0.0-20260331071634-4f359694d2d9 h1:vZ3dL8xwo2sy73aBVKs4AJiO5OCHRxMOJUwIYkp0CWs=
|
||||||
github.com/cloudwego/eino-ext/components/indexer/es8 v0.0.0-20260331071634-4f359694d2d9/go.mod h1:+oI0sr0rA0OHCxaQJ0rzMYld3LAODHhPKzBx5JYCya0=
|
github.com/cloudwego/eino-ext/components/indexer/es8 v0.0.0-20260331071634-4f359694d2d9/go.mod h1:+oI0sr0rA0OHCxaQJ0rzMYld3LAODHhPKzBx5JYCya0=
|
||||||
|
github.com/cloudwego/eino-ext/components/model/qwen v0.1.7 h1:8c1LB5lH+dERbf2twp18B1Y822JOQSsS6x7Vnksehk0=
|
||||||
|
github.com/cloudwego/eino-ext/components/model/qwen v0.1.7/go.mod h1:n4iuIUQeL3D8GRsGAhkgceRZpoyPQbqOXFMXM2Q4hNY=
|
||||||
github.com/cloudwego/eino-ext/components/retriever/es8 v0.0.0-20260331071634-4f359694d2d9 h1:Sl6giB1SJlA+ZlO0gzPH05IsUORtdYYPN6GiyH1B9MA=
|
github.com/cloudwego/eino-ext/components/retriever/es8 v0.0.0-20260331071634-4f359694d2d9 h1:Sl6giB1SJlA+ZlO0gzPH05IsUORtdYYPN6GiyH1B9MA=
|
||||||
github.com/cloudwego/eino-ext/components/retriever/es8 v0.0.0-20260331071634-4f359694d2d9/go.mod h1:H4kNmiTe2irnvipVNIP4q8yqXf2fZ6v24krvQYBtYb8=
|
github.com/cloudwego/eino-ext/components/retriever/es8 v0.0.0-20260331071634-4f359694d2d9/go.mod h1:H4kNmiTe2irnvipVNIP4q8yqXf2fZ6v24krvQYBtYb8=
|
||||||
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.14 h1:yOZII6VYaL00CVZYba+HUixFygsW0Xz/1QjQ5htj1Ls=
|
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.15 h1:LbdSG9+qWzzp9RFW6dSFkaUW171JvCoYn/K63zX6dQE=
|
||||||
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.14/go.mod h1:1xMQZ8eE11pkEoTAEy8UlaAY817qGVMvjpDPGSIO3Ns=
|
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.15/go.mod h1:p+l0zBB0GjjX8HTlbTs3g3KfUFwZC11bsCGZOXW/3L0=
|
||||||
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
|
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
|
||||||
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
|
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
|
||||||
github.com/cncf/xds/go v0.0.0-20210312221358-fbca930ec8ed/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
|
github.com/cncf/xds/go v0.0.0-20210312221358-fbca930ec8ed/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
|
||||||
@@ -531,8 +531,8 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
|
|||||||
github.com/mattn/go-runewidth v0.0.21 h1:jJKAZiQH+2mIinzCJIaIG9Be1+0NR+5sz/lYEEjdM8w=
|
github.com/mattn/go-runewidth v0.0.21 h1:jJKAZiQH+2mIinzCJIaIG9Be1+0NR+5sz/lYEEjdM8w=
|
||||||
github.com/mattn/go-runewidth v0.0.21/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
github.com/mattn/go-runewidth v0.0.21/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||||
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
||||||
github.com/meguminnnnnnnnn/go-openai v0.1.1 h1:u/IMMgrj/d617Dh/8BKAwlcstD74ynOJzCtVl+y8xAs=
|
github.com/meguminnnnnnnnn/go-openai v0.1.2 h1:iXombGGjqjBrmE9WaSidUhhi3YQhf42QTHvHLMkgvCA=
|
||||||
github.com/meguminnnnnnnnn/go-openai v0.1.1/go.mod h1:qs96ysDmxhE4BZoU45I43zcyfnaYxU3X+aRzLko/htY=
|
github.com/meguminnnnnnnnn/go-openai v0.1.2/go.mod h1:qs96ysDmxhE4BZoU45I43zcyfnaYxU3X+aRzLko/htY=
|
||||||
github.com/meilisearch/meilisearch-go v0.36.1 h1:mJTCJE5g7tRvaqKco6DfqOuJEjX+rRltDEnkEC02Y0M=
|
github.com/meilisearch/meilisearch-go v0.36.1 h1:mJTCJE5g7tRvaqKco6DfqOuJEjX+rRltDEnkEC02Y0M=
|
||||||
github.com/meilisearch/meilisearch-go v0.36.1/go.mod h1:hWcR0MuWLSzHfbz9GGzIr3s9rnXLm1jqkmHkJPbUSvM=
|
github.com/meilisearch/meilisearch-go v0.36.1/go.mod h1:hWcR0MuWLSzHfbz9GGzIr3s9rnXLm1jqkmHkJPbUSvM=
|
||||||
github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4=
|
github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4=
|
||||||
@@ -735,8 +735,8 @@ github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV
|
|||||||
github.com/volcengine/volc-sdk-golang v1.0.23/go.mod h1:AfG/PZRUkHJ9inETvbjNifTDgut25Wbkm2QoYBTbvyU=
|
github.com/volcengine/volc-sdk-golang v1.0.23/go.mod h1:AfG/PZRUkHJ9inETvbjNifTDgut25Wbkm2QoYBTbvyU=
|
||||||
github.com/volcengine/volc-sdk-golang v1.0.199 h1:zv9QOqTl/IsLwtfC37GlJtcz6vMAHi+pjq8ILWjLYUc=
|
github.com/volcengine/volc-sdk-golang v1.0.199 h1:zv9QOqTl/IsLwtfC37GlJtcz6vMAHi+pjq8ILWjLYUc=
|
||||||
github.com/volcengine/volc-sdk-golang v1.0.199/go.mod h1:stZX+EPgv1vF4nZwOlEe8iGcriUPRBKX8zA19gXycOQ=
|
github.com/volcengine/volc-sdk-golang v1.0.199/go.mod h1:stZX+EPgv1vF4nZwOlEe8iGcriUPRBKX8zA19gXycOQ=
|
||||||
github.com/volcengine/volcengine-go-sdk v1.0.181 h1:/3PB4M1N4fjMqiSKTJwX43EZ5Nn1HUOtQrSCk+22+wI=
|
github.com/volcengine/volcengine-go-sdk v1.2.9 h1:du2gnImtyWXKkQFnJW/GXCs+UBibGGOXIbP1Ams2pB8=
|
||||||
github.com/volcengine/volcengine-go-sdk v1.0.181/go.mod h1:gfEDc1s7SYaGoY+WH2dRrS3qiuDJMkwqyfXWCa7+7oA=
|
github.com/volcengine/volcengine-go-sdk v1.2.9/go.mod h1:oxoVo+A17kvkwPkIeIHPVLjSw7EQAm+l/Vau1YGHN+A=
|
||||||
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
|
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
|
||||||
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
||||||
github.com/x-cray/logrus-prefixed-formatter v0.5.2 h1:00txxvfBM9muc0jiLIEAkAcIMJzfthRT6usrui8uGmg=
|
github.com/x-cray/logrus-prefixed-formatter v0.5.2 h1:00txxvfBM9muc0jiLIEAkAcIMJzfthRT6usrui8uGmg=
|
||||||
@@ -1193,6 +1193,7 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0
|
|||||||
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
|
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
|
||||||
google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
|
google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
|
||||||
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||||
|
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||||
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
|
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
|
||||||
|
|||||||
20
main.go
20
main.go
@@ -11,6 +11,7 @@ import (
|
|||||||
|
|
||||||
"gitea.com/red-future/common/http"
|
"gitea.com/red-future/common/http"
|
||||||
"gitea.com/red-future/common/jaeger"
|
"gitea.com/red-future/common/jaeger"
|
||||||
|
"gitea.com/red-future/common/utils"
|
||||||
gmq "github.com/bjang03/gmq/core/gmq"
|
gmq "github.com/bjang03/gmq/core/gmq"
|
||||||
"github.com/bjang03/gmq/mq"
|
"github.com/bjang03/gmq/mq"
|
||||||
"github.com/bjang03/gmq/types"
|
"github.com/bjang03/gmq/types"
|
||||||
@@ -27,22 +28,17 @@ func main() {
|
|||||||
controller.Dataset,
|
controller.Dataset,
|
||||||
controller.Document,
|
controller.Document,
|
||||||
controller.DocumentChunk,
|
controller.DocumentChunk,
|
||||||
|
controller.Keyword,
|
||||||
|
controller.RAGQuery,
|
||||||
})
|
})
|
||||||
|
|
||||||
gmq.Init("config.yml")
|
err := utils.InitGseTool(ctx)
|
||||||
|
if err != nil {
|
||||||
if err := gmq.GetGmq("primary").GmqSubscribe(ctx, &mq.RedisSubMessage{
|
g.Log().Error(ctx, "gse 分词工具初始化失败:", err)
|
||||||
SubMessage: types.SubMessage{
|
|
||||||
Topic: public.KnowledgeDocumentVectorStatusTopic,
|
|
||||||
ConsumerName: public.KnowledgeDocumentVectorStatusConsumer,
|
|
||||||
AutoAck: public.KnowledgeDocumentVectorStatusAutoAck,
|
|
||||||
FetchCount: public.KnowledgeDocumentVectorStatusBatchSize,
|
|
||||||
HandleFunc: service.Document.DocsVectorStatusMsg,
|
|
||||||
},
|
|
||||||
}); err != nil {
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
gmq.Init("config.yml")
|
||||||
|
|
||||||
if err := gmq.GetGmq("primary").GmqSubscribe(ctx, &mq.RedisSubMessage{
|
if err := gmq.GetGmq("primary").GmqSubscribe(ctx, &mq.RedisSubMessage{
|
||||||
SubMessage: types.SubMessage{
|
SubMessage: types.SubMessage{
|
||||||
Topic: public.KnowledgeDocumentChunkTopic,
|
Topic: public.KnowledgeDocumentChunkTopic,
|
||||||
|
|||||||
@@ -84,12 +84,6 @@ type ProcessDocumentReq struct {
|
|||||||
DatasetId int64 `json:"datasetId" v:"required#数据集ID不能为空"`
|
DatasetId int64 `json:"datasetId" v:"required#数据集ID不能为空"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProcessDocumentRes 处理文件响应
|
|
||||||
type ProcessDocumentRes struct {
|
|
||||||
ChunkCount int64 `json:"chunkCount"`
|
|
||||||
CostTime int64 `json:"costTime"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ListDocumentChunkRPC struct {
|
type ListDocumentChunkRPC struct {
|
||||||
List []*DocumentChunkRPC `json:"list"`
|
List []*DocumentChunkRPC `json:"list"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ type ListKeywordReq struct {
|
|||||||
DatasetId int64 `json:"datasetId"`
|
DatasetId int64 `json:"datasetId"`
|
||||||
DocumentId int64 `json:"documentId"`
|
DocumentId int64 `json:"documentId"`
|
||||||
Word string `json:"word"`
|
Word string `json:"word"`
|
||||||
|
Words []string `json:"words"`
|
||||||
Keyword string `json:"keyword" dc:"关键词搜索"`
|
Keyword string `json:"keyword" dc:"关键词搜索"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,9 +63,11 @@ type ListKeywordRes struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type KeywordVO struct {
|
type KeywordVO struct {
|
||||||
Id int64 `json:"id,string" dc:"id"`
|
Id int64 `json:"id,string" dc:"id"`
|
||||||
Word string `json:"word" dc:"关键词名称"`
|
Word string `json:"word" dc:"关键词名称"`
|
||||||
Weight int16 `json:"weight" dc:"权重"`
|
Weight int16 `json:"weight" dc:"权重"`
|
||||||
CreatedAt *gtime.Time `json:"createdAt" dc:"创建时间"`
|
DatasetId int64 `json:"datasetId,string" dc:"数据集ID"`
|
||||||
UpdatedAt *gtime.Time `json:"updatedAt" dc:"更新时间"`
|
DocumentId int64 `json:"documentId,string" dc:"文档ID"`
|
||||||
|
CreatedAt *gtime.Time `json:"createdAt" dc:"创建时间"`
|
||||||
|
UpdatedAt *gtime.Time `json:"updatedAt" dc:"更新时间"`
|
||||||
}
|
}
|
||||||
|
|||||||
21
model/dto/rag_query.go
Normal file
21
model/dto/rag_query.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RAGQueryReq RAG查询请求
|
||||||
|
type RAGQueryReq struct {
|
||||||
|
g.Meta `path:"/ragQuery" method:"post" tags:"RAG查询" summary:"执行RAG查询" dc:"执行RAG查询"`
|
||||||
|
|
||||||
|
Content string `json:"content" v:"required#查询内容不能为空" dc:"用户问题"`
|
||||||
|
DatasetIds []int64 `json:"datasetIds" dc:"数据集ID"`
|
||||||
|
TopK int `json:"topK" d:"5" dc:"检索topK,默认5"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// RAGQueryRes RAG查询响应
|
||||||
|
type RAGQueryRes struct {
|
||||||
|
Answer string `json:"answer" dc:"生成的答案"`
|
||||||
|
DatasetId string `json:"datasetId" dc:"使用的数据集ID"`
|
||||||
|
Sources []string `json:"sources" dc:"参考来源"`
|
||||||
|
}
|
||||||
65
model/dto/task.go
Normal file
65
model/dto/task.go
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"rag/common/task"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WriteTaskProgressReq 写入任务进度请求
|
||||||
|
type WriteTaskProgressReq struct {
|
||||||
|
TaskType task.TaskType `json:"taskType" dc:"任务类型"`
|
||||||
|
Status task.TaskStatus `json:"status" dc:"任务状态"`
|
||||||
|
TaskId int64 `json:"taskId" dc:"任务ID"`
|
||||||
|
Remark string `json:"remark" dc:"备注"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateTaskReq 创建任务请求
|
||||||
|
type CreateTaskReq struct {
|
||||||
|
TaskType task.TaskType `json:"taskType" dc:"任务类型"`
|
||||||
|
Status task.TaskStatus `json:"status" dc:"任务状态"`
|
||||||
|
TaskId int64 `json:"taskId" dc:"任务ID"`
|
||||||
|
Remark string `json:"remark" dc:"备注"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTaskReq 更新任务请求
|
||||||
|
type UpdateTaskReq struct {
|
||||||
|
Id int64 `json:"id" dc:"任务ID"`
|
||||||
|
TaskId int64 `json:"taskId" dc:"任务ID"`
|
||||||
|
Status task.TaskStatus `json:"status" dc:"任务状态"`
|
||||||
|
Remark string `json:"remark" dc:"备注"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteTaskByTaskIdReq 删除任务请求
|
||||||
|
type DeleteTaskByTaskIdReq struct {
|
||||||
|
TaskId int64 `json:"taskId" v:"required#任务id不能为空"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTaskReq 获取任务请求
|
||||||
|
type GetTaskReq struct {
|
||||||
|
Id int64 `json:"id" dc:"任务ID"`
|
||||||
|
TaskId int64 `json:"taskId" dc:"任务ID"`
|
||||||
|
TaskType task.TaskType `json:"taskType" dc:"任务类型"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TaskVO 任务视图对象
|
||||||
|
type TaskVO struct {
|
||||||
|
Id int64 `json:"id" dc:"任务ID"`
|
||||||
|
TaskType task.TaskType `json:"taskType" dc:"任务类型"`
|
||||||
|
Status task.TaskStatus `json:"status" dc:"任务状态"`
|
||||||
|
Priority task.TaskPriority `json:"priority" dc:"任务优先级"`
|
||||||
|
ParentTaskID int64 `json:"parentTaskId" dc:"父任务ID"`
|
||||||
|
TotalItems int64 `json:"totalItems" dc:"总项数"`
|
||||||
|
ProcessedItems int64 `json:"processedItems" dc:"已处理项数"`
|
||||||
|
Progress float64 `json:"progress" dc:"进度百分比"`
|
||||||
|
StartTime *int64 `json:"startTime" dc:"开始时间戳"`
|
||||||
|
EndTime *int64 `json:"endTime" dc:"结束时间戳"`
|
||||||
|
Duration int64 `json:"duration" dc:"耗时(毫秒)"`
|
||||||
|
SuccessCount int64 `json:"successCount" dc:"成功数"`
|
||||||
|
FailCount int64 `json:"failCount" dc:"失败数"`
|
||||||
|
Executor string `json:"executor" dc:"执行器"`
|
||||||
|
DocumentID int64 `json:"documentId" dc:"文档ID"`
|
||||||
|
Remark string `json:"remark" dc:"备注"`
|
||||||
|
Creator string `json:"creator" dc:"创建人"`
|
||||||
|
CreatedAt int64 `json:"createdAt" dc:"创建时间"`
|
||||||
|
Updater string `json:"updater" dc:"更新人"`
|
||||||
|
UpdatedAt int64 `json:"updatedAt" dc:"更新时间"`
|
||||||
|
}
|
||||||
66
model/entity/task.go
Normal file
66
model/entity/task.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package entity
|
||||||
|
|
||||||
|
import (
|
||||||
|
"rag/common/task"
|
||||||
|
|
||||||
|
"gitea.com/red-future/common/beans"
|
||||||
|
)
|
||||||
|
|
||||||
|
type taskCol struct {
|
||||||
|
beans.SQLBaseCol
|
||||||
|
TaskId string
|
||||||
|
TaskType string
|
||||||
|
Status string
|
||||||
|
Executor string
|
||||||
|
Remark string
|
||||||
|
//Priority string
|
||||||
|
//ParentTaskId string
|
||||||
|
//TotalItems string
|
||||||
|
//ProcessedItems string
|
||||||
|
//Progress string
|
||||||
|
//StartTime string
|
||||||
|
//EndTime string
|
||||||
|
//Duration string
|
||||||
|
//SuccessCount string
|
||||||
|
//FailCount string
|
||||||
|
}
|
||||||
|
|
||||||
|
var TaskCol = taskCol{
|
||||||
|
SQLBaseCol: beans.DefSQLBaseCol,
|
||||||
|
TaskId: "task_id",
|
||||||
|
TaskType: "task_type",
|
||||||
|
Status: "status",
|
||||||
|
Executor: "executor",
|
||||||
|
Remark: "remark",
|
||||||
|
//Priority: "priority",
|
||||||
|
//ParentTaskId: "parent_task_id",
|
||||||
|
//TotalItems: "total_items",
|
||||||
|
//ProcessedItems: "processed_items",
|
||||||
|
//Progress: "progress",
|
||||||
|
//StartTime: "start_time",
|
||||||
|
//EndTime: "end_time",
|
||||||
|
//Duration: "duration",
|
||||||
|
//SuccessCount: "success_count",
|
||||||
|
//FailCount: "fail_count",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Task 任务记录表
|
||||||
|
type Task struct {
|
||||||
|
beans.SQLBaseDO `orm:",inline"`
|
||||||
|
|
||||||
|
TaskId int64 `orm:"task_id" json:"taskId" dc:"任务ID"`
|
||||||
|
TaskType task.TaskType `orm:"task_type" json:"taskType" dc:"任务类型"`
|
||||||
|
Status task.TaskStatus `orm:"status" json:"status" dc:"任务状态"`
|
||||||
|
Executor string `orm:"executor" json:"executor" dc:"执行器"`
|
||||||
|
Remark string `orm:"remark" json:"remark" dc:"备注"`
|
||||||
|
//Priority task.TaskPriority `orm:"priority" json:"priority" dc:"任务优先级"`
|
||||||
|
//ParentTaskId int64 `orm:"parent_task_id" json:"parentTaskId" dc:"父任务ID"`
|
||||||
|
//TotalItems int64 `orm:"total_items" json:"totalItems" dc:"总项数"`
|
||||||
|
//ProcessedItems int64 `orm:"processed_items" json:"processedItems" dc:"已处理项数"`
|
||||||
|
//SuccessCount int64 `orm:"success_count" json:"successCount" dc:"成功数"`
|
||||||
|
//FailCount int64 `orm:"fail_count" json:"failCount" dc:"失败数"`
|
||||||
|
//Progress float64 `orm:"progress" json:"progress" dc:"进度百分比"`
|
||||||
|
//StartTime *gtime.Time `orm:"start_time" json:"startTime" dc:"开始时间戳"`
|
||||||
|
//EndTime *gtime.Time `orm:"end_time" json:"endTime" dc:"结束时间戳"`
|
||||||
|
//Duration int64 `orm:"duration" json:"duration" dc:"耗时(毫秒)"`
|
||||||
|
}
|
||||||
@@ -5,17 +5,14 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"rag/common/eino"
|
"rag/common/eino"
|
||||||
"rag/common/gse"
|
"rag/common/task"
|
||||||
"rag/consts/document"
|
"rag/consts/document"
|
||||||
"rag/consts/public"
|
"rag/consts/public"
|
||||||
"rag/dao"
|
"rag/dao"
|
||||||
"rag/model/dto"
|
"rag/model/dto"
|
||||||
"rag/model/entity"
|
"rag/model/entity"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"gitea.com/red-future/common/beans"
|
|
||||||
"gitea.com/red-future/common/db/gfdb"
|
"gitea.com/red-future/common/db/gfdb"
|
||||||
"gitea.com/red-future/common/full-text-search/meilisearch"
|
"gitea.com/red-future/common/full-text-search/meilisearch"
|
||||||
"gitea.com/red-future/common/http"
|
"gitea.com/red-future/common/http"
|
||||||
@@ -29,6 +26,7 @@ import (
|
|||||||
"github.com/gogf/gf/v2/database/gdb"
|
"github.com/gogf/gf/v2/database/gdb"
|
||||||
"github.com/gogf/gf/v2/database/gredis"
|
"github.com/gogf/gf/v2/database/gredis"
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
|
"github.com/gogf/gf/v2/os/grpool"
|
||||||
"github.com/gogf/gf/v2/util/gconv"
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -54,7 +52,13 @@ func (s *documentService) Create(ctx context.Context, req *dto.CreateDocumentReq
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
res = &dto.CreateDocumentRes{Id: id}
|
res = &dto.CreateDocumentRes{Id: id}
|
||||||
|
// 写入任务进度待处理 任务类型为文档解析
|
||||||
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
|
TaskId: id,
|
||||||
|
TaskType: task.TaskTypeDocParse,
|
||||||
|
Status: task.TaskStatusPending,
|
||||||
|
Remark: "文档上传成功待解析: " + req.Title,
|
||||||
|
})
|
||||||
return
|
return
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -79,11 +83,20 @@ func (s *documentService) Delete(ctx context.Context, req *dto.DeleteDocumentReq
|
|||||||
DocumentCount: -1,
|
DocumentCount: -1,
|
||||||
DocumentSize: -docs.FileSize,
|
DocumentSize: -docs.FileSize,
|
||||||
}
|
}
|
||||||
_, err = dao.Dataset.Update(ctx, datasetReq)
|
if _, err = dao.Dataset.Update(ctx, datasetReq); err != nil {
|
||||||
if err != nil {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
_, err = dao.Document.Delete(ctx, req)
|
|
||||||
|
if _, err = dao.Document.Delete(ctx, req); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err = dao.Task.DeleteByTaskId(ctx, &dto.DeleteTaskByTaskIdReq{
|
||||||
|
TaskId: docs.Id,
|
||||||
|
}); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -107,118 +120,159 @@ func (s *documentService) List(ctx context.Context, req *dto.ListDocumentReq) (r
|
|||||||
Total: total,
|
Total: total,
|
||||||
}
|
}
|
||||||
err = gconv.Struct(list, &res.List)
|
err = gconv.Struct(list, &res.List)
|
||||||
|
|
||||||
//eino.TestIndexer()
|
|
||||||
//eino.TestRetriever()
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process 处理文件(使用eino框架切分和向量化)
|
// Process 处理文件(使用eino框架切分和向量化)
|
||||||
func (s *documentService) Process(ctx context.Context, req *dto.ProcessDocumentReq) (res *dto.ProcessDocumentRes, err error) {
|
func (s *documentService) Process(ctx context.Context, req *dto.ProcessDocumentReq) (err error) {
|
||||||
startTime := time.Now()
|
|
||||||
|
|
||||||
// 1. 查询文件信息
|
// 1. 查询文件信息
|
||||||
documentReq := dto.GetDocumentReq{Id: req.Id}
|
documentReq := dto.GetDocumentReq{Id: req.Id}
|
||||||
doc, err := dao.Document.GetByID(ctx, &documentReq)
|
doc, err := dao.Document.GetByID(ctx, &documentReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
if g.IsEmpty(doc) {
|
if g.IsEmpty(doc) {
|
||||||
return nil, errors.New("document not found")
|
return errors.New("document not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 使用eino框架进行文件切分(并发执行)
|
// 2. 更新文档状态为处理中
|
||||||
var vectorDocsCount, chunks int64
|
|
||||||
// 用 gopool 或者简单的错误等待,绝对不用裸 goroutine
|
|
||||||
var err1, err2, err3 error
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(3)
|
|
||||||
|
|
||||||
// 任务1
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
vectorDocsCount, chunks, err1 = s.sqlSplitDocument(ctx, doc)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// 任务2
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
err2 = s.esSplitDocument(ctx, doc)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// 任务3
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
err3 = s.extractDocument(ctx, doc)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// 直接等待,不使用通道,避免泄漏
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
updateDocumentReq := new(dto.UpdateDocumentReq)
|
updateDocumentReq := new(dto.UpdateDocumentReq)
|
||||||
updateDocumentReq.Id = req.Id
|
updateDocumentReq.Id = req.Id
|
||||||
|
updateDocumentReq.VectorStatus = document.VectorStatusProcessing.Code()
|
||||||
// 统一判断错误
|
|
||||||
if err1 != nil || err2 != nil || err3 != nil {
|
|
||||||
// 更新文档状态
|
|
||||||
updateDocumentReq.VectorStatus = document.VectorStatusFailed.Code()
|
|
||||||
if _, err = dao.Document.Update(ctx, updateDocumentReq); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err1 != nil {
|
|
||||||
return nil, err1
|
|
||||||
}
|
|
||||||
if err2 != nil {
|
|
||||||
return nil, err2
|
|
||||||
}
|
|
||||||
return nil, err3
|
|
||||||
}
|
|
||||||
|
|
||||||
// 4. 更新文件状态为处理中和切分数量
|
|
||||||
if vectorDocsCount > 0 {
|
|
||||||
updateDocumentReq.VectorStatus = document.VectorStatusProcessing.Code()
|
|
||||||
} else {
|
|
||||||
updateDocumentReq.VectorStatus = document.VectorStatusCompleted.Code()
|
|
||||||
}
|
|
||||||
updateDocumentReq.ChunkCount = chunks
|
|
||||||
if _, err = dao.Document.Update(ctx, updateDocumentReq); err != nil {
|
if _, err = dao.Document.Update(ctx, updateDocumentReq); err != nil {
|
||||||
|
// 写入任务进度失败 任务类型为文档解析
|
||||||
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
|
TaskId: req.Id,
|
||||||
|
TaskType: task.TaskTypeDocParse,
|
||||||
|
Status: task.TaskStatusFailed,
|
||||||
|
Remark: "更新文档状态失败: " + err.Error(),
|
||||||
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// 写入任务进度进行中 任务类型为文档解析
|
||||||
costTime := time.Since(startTime).Milliseconds()
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
|
TaskId: req.Id,
|
||||||
return &dto.ProcessDocumentRes{
|
TaskType: task.TaskTypeDocParse,
|
||||||
ChunkCount: chunks,
|
Status: task.TaskStatusRunning,
|
||||||
CostTime: costTime,
|
Remark: "文档解析开始",
|
||||||
}, nil
|
})
|
||||||
}
|
|
||||||
|
|
||||||
func (s *documentService) extractDocument(ctx context.Context, doc *entity.Document) (err error) {
|
|
||||||
// 1. 加载文件
|
|
||||||
docs, err := s.loadDocument(ctx, doc)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var words []gse.Keyword
|
// ======================
|
||||||
|
// 核心:grpool + g.Try 最佳实践
|
||||||
|
// ======================
|
||||||
|
taskCtx, cancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
|
// 任务1: SQL 切分文档
|
||||||
|
grpool.Add(taskCtx, func(ctx context.Context) {
|
||||||
|
g.TryCatch(ctx, func(ctx context.Context) {
|
||||||
|
if innerErr := s.sqlSplitDocument(ctx, doc); innerErr != nil {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
}, func(ctx context.Context, err error) {
|
||||||
|
cancel()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// 任务2: ES 切分文档
|
||||||
|
grpool.Add(taskCtx, func(ctx context.Context) {
|
||||||
|
g.TryCatch(ctx, func(ctx context.Context) {
|
||||||
|
if innerErr := s.esSplitDocument(ctx, doc); innerErr != nil {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
}, func(ctx context.Context, err error) {
|
||||||
|
cancel()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// 任务3: 提取文档
|
||||||
|
grpool.Add(taskCtx, func(ctx context.Context) {
|
||||||
|
g.TryCatch(ctx, func(ctx context.Context) {
|
||||||
|
if innerErr := s.extractDocument(ctx, doc); innerErr != nil {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
}, func(ctx context.Context, err error) {
|
||||||
|
cancel()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractDocument 关键词提取(支持取消)
|
||||||
|
func (s *documentService) extractDocument(ctx context.Context, doc *entity.Document) (err error) {
|
||||||
|
// ========== 取消检查 1:方法入口 ==========
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
// 写入任务进度失败 任务类型为关键字存储
|
||||||
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
|
TaskId: doc.Id,
|
||||||
|
TaskType: task.TaskTypeExtractKeywords,
|
||||||
|
Status: task.TaskStatusFailed,
|
||||||
|
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||||
|
})
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1. 加载文件
|
||||||
|
docs, err := s.loadDocument(ctx, doc)
|
||||||
|
if err != nil {
|
||||||
|
// 写入任务进度失败 任务类型为关键字存储
|
||||||
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
|
TaskId: doc.Id,
|
||||||
|
TaskType: task.TaskTypeExtractKeywords,
|
||||||
|
Status: task.TaskStatusFailed,
|
||||||
|
Remark: "加载文件失败: " + err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var words []utils.Keyword
|
||||||
if len(docs[0].Content) < 500 {
|
if len(docs[0].Content) < 500 {
|
||||||
words = gse.GseTool.Extract(docs[0].Content, 4)
|
words = utils.GseTool.Extract(docs[0].Content, 4)
|
||||||
} else if len(docs[0].Content) < 2000 {
|
} else if len(docs[0].Content) < 2000 {
|
||||||
words = gse.GseTool.Extract(docs[0].Content, 8)
|
words = utils.GseTool.Extract(docs[0].Content, 8)
|
||||||
} else if len(docs[0].Content) < 5000 {
|
} else if len(docs[0].Content) < 5000 {
|
||||||
words = gse.GseTool.Extract(docs[0].Content, 13)
|
words = utils.GseTool.Extract(docs[0].Content, 13)
|
||||||
} else {
|
} else {
|
||||||
var docsSplit []*schema.Document
|
var docsSplit []*schema.Document
|
||||||
docsSplit, err = eino.RecursiveSplitDocument(ctx, docs)
|
docsSplit, err = eino.RecursiveSplitDocument(ctx, docs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
|
TaskId: doc.Id,
|
||||||
|
TaskType: task.TaskTypeExtractKeywords,
|
||||||
|
Status: task.TaskStatusFailed,
|
||||||
|
Remark: "递归分割文档失败: " + err.Error(),
|
||||||
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// ========== 取消检查 2:循环内部 ==========
|
||||||
for _, t := range docsSplit {
|
for _, t := range docsSplit {
|
||||||
words = append(words, gse.GseTool.Extract(t.Content, 6)...)
|
if ctx.Err() != nil {
|
||||||
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
|
TaskId: doc.Id,
|
||||||
|
TaskType: task.TaskTypeExtractKeywords,
|
||||||
|
Status: task.TaskStatusFailed,
|
||||||
|
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||||
|
})
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
words = append(words, utils.GseTool.Extract(t.Content, 6)...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ========== 取消检查 3:批量操作前 ==========
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
|
TaskId: doc.Id,
|
||||||
|
TaskType: task.TaskTypeExtractKeywords,
|
||||||
|
Status: task.TaskStatusFailed,
|
||||||
|
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||||
|
})
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
var keywordReqs = make([]*dto.CreateKeywordReq, 0)
|
var keywordReqs = make([]*dto.CreateKeywordReq, 0)
|
||||||
for _, word := range words {
|
for _, word := range words {
|
||||||
keywordReqs = append(keywordReqs, &dto.CreateKeywordReq{
|
keywordReqs = append(keywordReqs, &dto.CreateKeywordReq{
|
||||||
@@ -231,37 +285,111 @@ func (s *documentService) extractDocument(ctx context.Context, doc *entity.Docum
|
|||||||
if len(keywordReqs) > 0 {
|
if len(keywordReqs) > 0 {
|
||||||
_, err = dao.Keyword.BatchSaveOrUpdate(ctx, keywordReqs)
|
_, err = dao.Keyword.BatchSaveOrUpdate(ctx, keywordReqs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// 写入任务进度失败 任务类型为关键字存储
|
||||||
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
|
TaskId: doc.Id,
|
||||||
|
TaskType: task.TaskTypeExtractKeywords,
|
||||||
|
Status: task.TaskStatusFailed,
|
||||||
|
Remark: "关键字存储失败: " + err.Error(),
|
||||||
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// 写入任务进度已完成 任务类型为关键字存储
|
||||||
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
|
TaskId: doc.Id,
|
||||||
|
TaskType: task.TaskTypeExtractKeywords,
|
||||||
|
Status: task.TaskStatusCompleted,
|
||||||
|
Remark: "关键字提取完成",
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
// 写入任务进度已完成 任务类型为关键字存储
|
||||||
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
|
TaskId: doc.Id,
|
||||||
|
TaskType: task.TaskTypeExtractKeywords,
|
||||||
|
Status: task.TaskStatusCompleted,
|
||||||
|
Remark: "没有提取到关键词,关键字提取完成",
|
||||||
|
})
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Document) (vectorDocsCount, docsSplitCount int64, err error) {
|
// sqlSplitDocument SQL切分(支持取消)
|
||||||
|
func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Document) (err error) {
|
||||||
|
// ========== 取消检查 1:方法入口 ==========
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
|
TaskId: doc.Id,
|
||||||
|
TaskType: task.TaskTypeGenerateVector,
|
||||||
|
Status: task.TaskStatusFailed,
|
||||||
|
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||||
|
})
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
// 1. 加载文件
|
// 1. 加载文件
|
||||||
docs, err := s.loadDocument(ctx, doc)
|
docs, err := s.loadDocument(ctx, doc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// 写入任务进度失败 任务类型为sql存储
|
||||||
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
|
TaskId: doc.Id,
|
||||||
|
TaskType: task.TaskTypeGenerateVector,
|
||||||
|
Status: task.TaskStatusFailed,
|
||||||
|
Remark: "加载文件失败: " + err.Error(),
|
||||||
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 语义切分文件
|
// 2. 语义切分文件
|
||||||
docsSplit, err := eino.SemanticSplitDocument(ctx, docs)
|
docsSplit, err := eino.SemanticSplitDocument(ctx, docs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// 写入任务进度失败 任务类型为sql存储
|
||||||
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
|
TaskId: doc.Id,
|
||||||
|
TaskType: task.TaskTypeGenerateVector,
|
||||||
|
Status: task.TaskStatusFailed,
|
||||||
|
Remark: "文档切分失败: " + err.Error(),
|
||||||
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
docsSplitCount = gconv.Int64(len(docsSplit))
|
|
||||||
// 2. 获取历史数据
|
// 2. 获取历史数据
|
||||||
err = s.getHistoryData(ctx, doc, public.KnowledgeLockSqlKey, public.KnowledgeContentHashSqlKey)
|
err = s.getHistoryData(ctx, doc, public.KnowledgeLockSqlKey, public.KnowledgeContentHashSqlKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// 写入任务进度失败 任务类型为sql存储
|
||||||
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
|
TaskId: doc.Id,
|
||||||
|
TaskType: task.TaskTypeGenerateVector,
|
||||||
|
Status: task.TaskStatusFailed,
|
||||||
|
Remark: "获取历史数据失败: " + err.Error(),
|
||||||
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 组装向量文档
|
// 3. 组装向量文档
|
||||||
var docsChunk = make([]*schema.Document, 0)
|
var docsChunk = make([]*schema.Document, 0)
|
||||||
for i, t := range docsSplit {
|
for i, t := range docsSplit {
|
||||||
|
// ========== 取消检查 2:循环内部 ==========
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
|
TaskId: doc.Id,
|
||||||
|
TaskType: task.TaskTypeGenerateVector,
|
||||||
|
Status: task.TaskStatusFailed,
|
||||||
|
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||||
|
})
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
contentHash := gmd5.MustEncryptString(t.Content)
|
contentHash := gmd5.MustEncryptString(t.Content)
|
||||||
// 检查是否重复
|
|
||||||
var success bool
|
var success bool
|
||||||
success, err = s.checkRepeat(ctx, public.KnowledgeContentHashSqlKey, contentHash)
|
success, err = s.checkRepeat(ctx, public.KnowledgeContentHashSqlKey, contentHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// 写入任务进度失败 任务类型为sql存储
|
||||||
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
|
TaskId: doc.Id,
|
||||||
|
TaskType: task.TaskTypeGenerateVector,
|
||||||
|
Status: task.TaskStatusFailed,
|
||||||
|
Remark: "检查重复数据失败: " + err.Error(),
|
||||||
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !success {
|
if !success {
|
||||||
@@ -277,6 +405,18 @@ func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Docu
|
|||||||
t.MetaData = metaData
|
t.MetaData = metaData
|
||||||
docsChunk = append(docsChunk, t)
|
docsChunk = append(docsChunk, t)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ========== 取消检查 3:批量发送前 ==========
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
|
TaskId: doc.Id,
|
||||||
|
TaskType: task.TaskTypeGenerateVector,
|
||||||
|
Status: task.TaskStatusFailed,
|
||||||
|
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||||
|
})
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
// 4. 发送消息到队列
|
// 4. 发送消息到队列
|
||||||
if len(docsChunk) > 0 {
|
if len(docsChunk) > 0 {
|
||||||
err = gmq.GetGmq("primary").GmqPublish(ctx, &mq.RedisPubMessage{
|
err = gmq.GetGmq("primary").GmqPublish(ctx, &mq.RedisPubMessage{
|
||||||
@@ -285,41 +425,117 @@ func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Docu
|
|||||||
Data: docsChunk,
|
Data: docsChunk,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
// 写入任务进度失败 任务类型为sql存储
|
||||||
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
|
TaskId: doc.Id,
|
||||||
|
TaskType: task.TaskTypeGenerateVector,
|
||||||
|
Status: task.TaskStatusFailed,
|
||||||
|
Remark: "发送消息到队列失败: " + err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 写入任务进度进行中 任务类型为sql存储
|
||||||
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
|
TaskId: doc.Id,
|
||||||
|
TaskType: task.TaskTypeGenerateVector,
|
||||||
|
Status: task.TaskStatusRunning,
|
||||||
|
Remark: "向量生成任务已提交到队列",
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
// 写入任务进度已完成 任务类型为sql存储
|
||||||
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
|
TaskId: doc.Id,
|
||||||
|
TaskType: task.TaskTypeGenerateVector,
|
||||||
|
Status: task.TaskStatusCompleted,
|
||||||
|
Remark: "无需生成向量,任务完成",
|
||||||
|
})
|
||||||
}
|
}
|
||||||
vectorDocsCount = gconv.Int64(len(docsChunk))
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// esSplitDocument ES切分(支持取消)
|
||||||
func (s *documentService) esSplitDocument(ctx context.Context, doc *entity.Document) (err error) {
|
func (s *documentService) esSplitDocument(ctx context.Context, doc *entity.Document) (err error) {
|
||||||
|
// ========== 取消检查 1:方法入口 ==========
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
|
TaskId: doc.Id,
|
||||||
|
TaskType: task.TaskTypeFullTextSearch,
|
||||||
|
Status: task.TaskStatusFailed,
|
||||||
|
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||||
|
})
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
// 1. 加载文件
|
// 1. 加载文件
|
||||||
docs, err := s.loadDocument(ctx, doc)
|
docs, err := s.loadDocument(ctx, doc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// 写入任务进度失败 任务类型为es存储
|
||||||
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
|
TaskId: doc.Id,
|
||||||
|
TaskType: task.TaskTypeFullTextSearch,
|
||||||
|
Status: task.TaskStatusFailed,
|
||||||
|
Remark: "加载文件失败: " + err.Error(),
|
||||||
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 递归切分文件
|
// 2. 递归切分文件
|
||||||
docsSplit, err := eino.RecursiveSplitDocument(ctx, docs)
|
docsSplit, err := eino.RecursiveSplitDocument(ctx, docs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// 写入任务进度失败 任务类型为es存储
|
||||||
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
|
TaskId: doc.Id,
|
||||||
|
TaskType: task.TaskTypeFullTextSearch,
|
||||||
|
Status: task.TaskStatusFailed,
|
||||||
|
Remark: "文档切分失败: " + err.Error(),
|
||||||
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 获取历史数据
|
// 2. 获取历史数据
|
||||||
err = s.getHistoryData(ctx, doc, public.KnowledgeLockEsKey, public.KnowledgeContentHashEsKey)
|
err = s.getHistoryData(ctx, doc, public.KnowledgeLockEsKey, public.KnowledgeContentHashEsKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// 写入任务进度失败 任务类型为es存储
|
||||||
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
|
TaskId: doc.Id,
|
||||||
|
TaskType: task.TaskTypeFullTextSearch,
|
||||||
|
Status: task.TaskStatusFailed,
|
||||||
|
Remark: "获取历史数据失败: " + err.Error(),
|
||||||
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 组装向量文档并同时构建meilisearch文档
|
// 3. 组装向量文档并同时构建meilisearch文档
|
||||||
var meiliDocs = make([]interface{}, 0)
|
var meiliDocs = make([]interface{}, 0)
|
||||||
for i, t := range docsSplit {
|
for i, t := range docsSplit {
|
||||||
|
// ========== 取消检查 2:循环内部 ==========
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
|
TaskId: doc.Id,
|
||||||
|
TaskType: task.TaskTypeFullTextSearch,
|
||||||
|
Status: task.TaskStatusFailed,
|
||||||
|
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||||
|
})
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
contentHash := gmd5.MustEncryptString(t.Content)
|
contentHash := gmd5.MustEncryptString(t.Content)
|
||||||
// 检查是否重复
|
|
||||||
var success bool
|
var success bool
|
||||||
success, err = s.checkRepeat(ctx, public.KnowledgeContentHashEsKey, contentHash)
|
success, err = s.checkRepeat(ctx, public.KnowledgeContentHashEsKey, contentHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// 写入任务进度失败 任务类型为es存储
|
||||||
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
|
TaskId: doc.Id,
|
||||||
|
TaskType: task.TaskTypeFullTextSearch,
|
||||||
|
Status: task.TaskStatusFailed,
|
||||||
|
Remark: "检查重复数据失败: " + err.Error(),
|
||||||
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !success {
|
if !success {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// 构建Meilisearch文档
|
|
||||||
meiliDocs = append(meiliDocs, map[string]interface{}{
|
meiliDocs = append(meiliDocs, map[string]interface{}{
|
||||||
entity.DocumentChunkCol.Id: contentHash,
|
entity.DocumentChunkCol.Id: contentHash,
|
||||||
entity.DocumentChunkCol.DatasetId: doc.DatasetId,
|
entity.DocumentChunkCol.DatasetId: doc.DatasetId,
|
||||||
@@ -329,12 +545,45 @@ func (s *documentService) esSplitDocument(ctx context.Context, doc *entity.Docum
|
|||||||
entity.DocumentChunkCol.ChunkIndex: i,
|
entity.DocumentChunkCol.ChunkIndex: i,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ========== 取消检查 3:批量写入前 ==========
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
|
TaskId: doc.Id,
|
||||||
|
TaskType: task.TaskTypeFullTextSearch,
|
||||||
|
Status: task.TaskStatusFailed,
|
||||||
|
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||||
|
})
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
// 4. 写入到meilisearch数据库中
|
// 4. 写入到meilisearch数据库中
|
||||||
if len(meiliDocs) > 0 {
|
if len(meiliDocs) > 0 {
|
||||||
if _, err = meilisearch.DB().InsertMany(ctx, meiliDocs, public.IndexNameDocumentChunk); err != nil {
|
if _, err = meilisearch.DB().InsertMany(ctx, meiliDocs, public.IndexNameDocumentChunk); err != nil {
|
||||||
g.Log().Errorf(ctx, "写入meilisearch失败: %v", err)
|
// 写入任务进度失败 任务类型为meilisearch存储
|
||||||
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
|
TaskId: doc.Id,
|
||||||
|
TaskType: task.TaskTypeFullTextSearch,
|
||||||
|
Status: task.TaskStatusFailed,
|
||||||
|
Remark: "写入meilisearch失败: " + err.Error(),
|
||||||
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// 写入任务进度已完成 任务类型为meilisearch存储
|
||||||
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
|
TaskId: doc.Id,
|
||||||
|
TaskType: task.TaskTypeFullTextSearch,
|
||||||
|
Status: task.TaskStatusCompleted,
|
||||||
|
Remark: "全文检索数据写入完成",
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
// 写入任务进度已完成 任务类型为meilisearch存储
|
||||||
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
|
TaskId: doc.Id,
|
||||||
|
TaskType: task.TaskTypeFullTextSearch,
|
||||||
|
Status: task.TaskStatusCompleted,
|
||||||
|
Remark: "无需生成全文检索数据,任务完成",
|
||||||
|
})
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -467,20 +716,3 @@ func (s *documentService) checkRepeat(ctx context.Context, contentKey, contentHa
|
|||||||
success = val.Bool()
|
success = val.Bool()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *documentService) DocsVectorStatusMsg(ctx context.Context, msg any) (err error) {
|
|
||||||
var req = new(dto.KnowledgeDocumentMsg)
|
|
||||||
if err = gconv.Struct(msg, &req); err != nil {
|
|
||||||
g.Log().Error(ctx, "DocsVectorStatusMsg err:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ctx = context.WithValue(ctx, "user", &beans.User{
|
|
||||||
TenantId: req.TenantId,
|
|
||||||
UserName: req.Creator,
|
|
||||||
})
|
|
||||||
_, err = dao.Document.Update(ctx, &dto.UpdateDocumentReq{
|
|
||||||
Id: req.Id,
|
|
||||||
VectorStatus: req.VectorStatus,
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -3,15 +3,11 @@ package service
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"rag/common/eino"
|
"rag/common/eino"
|
||||||
"rag/consts/document"
|
"rag/common/task"
|
||||||
"rag/consts/public"
|
|
||||||
"rag/dao"
|
"rag/dao"
|
||||||
"rag/model/dto"
|
"rag/model/dto"
|
||||||
"rag/model/entity"
|
"rag/model/entity"
|
||||||
|
|
||||||
gmq "github.com/bjang03/gmq/core/gmq"
|
|
||||||
"github.com/bjang03/gmq/mq"
|
|
||||||
"github.com/bjang03/gmq/types"
|
|
||||||
"github.com/cloudwego/eino/components/indexer"
|
"github.com/cloudwego/eino/components/indexer"
|
||||||
"github.com/cloudwego/eino/schema"
|
"github.com/cloudwego/eino/schema"
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
@@ -22,10 +18,6 @@ var DocumentChunk = new(documentChunkService)
|
|||||||
|
|
||||||
type documentChunkService struct{}
|
type documentChunkService struct{}
|
||||||
|
|
||||||
const (
|
|
||||||
DatasetIndexStatusReady = "ready"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Update 更新文件块
|
// Update 更新文件块
|
||||||
func (s *documentChunkService) Update(ctx context.Context, req *dto.UpdateDocumentChunkReq) (err error) {
|
func (s *documentChunkService) Update(ctx context.Context, req *dto.UpdateDocumentChunkReq) (err error) {
|
||||||
_, err = dao.DocumentChunk.Update(ctx, req)
|
_, err = dao.DocumentChunk.Update(ctx, req)
|
||||||
@@ -60,32 +52,29 @@ func (s *documentChunkService) DocsChunkMsg(ctx context.Context, msg any) (err e
|
|||||||
idx := eino.NewPGVectorIndexer(&eino.PGVectorIndexerOptions{
|
idx := eino.NewPGVectorIndexer(&eino.PGVectorIndexerOptions{
|
||||||
BatchSize: 10,
|
BatchSize: 10,
|
||||||
})
|
})
|
||||||
|
documentId := gconv.Int64(docs[0].MetaData[entity.DocumentChunkCol.DocumentId])
|
||||||
rows, err := idx.Store(ctx, docs, indexer.WithEmbedding(eino.EmbedderDashscope))
|
rows, err := idx.Store(ctx, docs, indexer.WithEmbedding(eino.EmbedderDashscope))
|
||||||
if err != nil || rows == 0 {
|
if err != nil || rows == 0 {
|
||||||
g.Log().Error(ctx, "DocsChunkMsg rows: , err:", rows, err)
|
g.Log().Error(ctx, "DocsChunkMsg rows: , err:", rows, err)
|
||||||
|
// 写入任务进度失败 任务类型为sql存储
|
||||||
|
remark := " 向量存储数量: " + gconv.String(rows)
|
||||||
|
if err != nil {
|
||||||
|
remark = "向量存储失败: " + err.Error()
|
||||||
|
}
|
||||||
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
|
TaskId: documentId,
|
||||||
|
TaskType: task.TaskTypeGenerateVector,
|
||||||
|
Status: task.TaskStatusFailed,
|
||||||
|
Remark: remark,
|
||||||
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
tenantId := gconv.Uint64(docs[0].MetaData[entity.DocumentChunkCol.TenantId])
|
// 写入任务进度成功 任务类型为sql存储
|
||||||
creator := gconv.String(docs[0].MetaData[entity.DocumentChunkCol.Creator])
|
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||||
documentId := gconv.Int64(docs[0].MetaData[entity.DocumentChunkCol.DocumentId])
|
TaskId: documentId,
|
||||||
err = s.publishKnowledgeDocumentMsg(ctx, tenantId, creator, documentId, document.VectorStatusCompleted.Code())
|
TaskType: task.TaskTypeGenerateVector,
|
||||||
|
Status: task.TaskStatusCompleted,
|
||||||
return
|
Remark: "向量生成完成",
|
||||||
}
|
|
||||||
|
|
||||||
// publishKnowledgeDocumentMsg 发布消息
|
|
||||||
func (s *documentChunkService) publishKnowledgeDocumentMsg(ctx context.Context, tenantId uint64, creator string, documentId int64, vectorStatus document.VectorStatus) (err error) {
|
|
||||||
knowledgeDocumentMsg := dto.KnowledgeDocumentMsg{
|
|
||||||
TenantId: tenantId,
|
|
||||||
Creator: creator,
|
|
||||||
Id: documentId,
|
|
||||||
VectorStatus: vectorStatus,
|
|
||||||
}
|
|
||||||
err = gmq.GetGmq("primary").GmqPublish(ctx, &mq.RedisPubMessage{
|
|
||||||
PubMessage: types.PubMessage{
|
|
||||||
Topic: public.KnowledgeDocumentVectorStatusTopic,
|
|
||||||
Data: knowledgeDocumentMsg,
|
|
||||||
},
|
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
52
service/rag_query.go
Normal file
52
service/rag_query.go
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"rag/common/eino"
|
||||||
|
"rag/model/dto"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/components/retriever"
|
||||||
|
"github.com/gogf/gf/v2/os/glog"
|
||||||
|
)
|
||||||
|
|
||||||
|
var RAGQuery = new(ragQueryService)
|
||||||
|
|
||||||
|
type ragQueryService struct{}
|
||||||
|
|
||||||
|
// Query 执行RAG查询
|
||||||
|
func (s *ragQueryService) Query(ctx context.Context, req *dto.RAGQueryReq) (*dto.RAGQueryRes, error) {
|
||||||
|
if req.TopK <= 0 {
|
||||||
|
req.TopK = 5
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. 使用向量检索器进行查询
|
||||||
|
r, err := eino.NewPGVectorRetriever(&eino.PGVectorRetrieverConfig{
|
||||||
|
Embedder: eino.EmbedderDashscope,
|
||||||
|
DefaultTopK: req.TopK,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
glog.Errorf(ctx, "初始化向量检索器失败: %v", err)
|
||||||
|
return nil, fmt.Errorf("初始化向量检索器失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. 执行向量检索
|
||||||
|
docs, err := r.Retrieve(ctx, req.Content, retriever.WithEmbedding(eino.EmbedderDashscope), retriever.WithDSLInfo(map[string]any{
|
||||||
|
"dataset_ids": req.DatasetIds,
|
||||||
|
}))
|
||||||
|
if err != nil {
|
||||||
|
glog.Errorf(ctx, "向量检索失败: %v", err)
|
||||||
|
return nil, fmt.Errorf("向量检索失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
replyMsg, sources, err := eino.NewChatModel(ctx, req.Content, docs)
|
||||||
|
if err != nil {
|
||||||
|
glog.Errorf(ctx, "向量检索失败: %v", err)
|
||||||
|
return nil, fmt.Errorf("向量检索失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &dto.RAGQueryRes{
|
||||||
|
Answer: replyMsg.Content,
|
||||||
|
Sources: sources,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
107
service/task.go
Normal file
107
service/task.go
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"rag/dao"
|
||||||
|
"rag/model/dto"
|
||||||
|
|
||||||
|
"rag/common/task"
|
||||||
|
|
||||||
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
var Task = new(taskService)
|
||||||
|
|
||||||
|
type taskService struct{}
|
||||||
|
|
||||||
|
// WriteTaskProgress 写入任务进度(核心方法)
|
||||||
|
func (s *taskService) WriteTaskProgress(ctx context.Context, req *dto.WriteTaskProgressReq) (err error) {
|
||||||
|
t, total, err := dao.Task.Get(ctx, &dto.GetTaskReq{
|
||||||
|
TaskId: req.TaskId,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
g.Log().Errorf(ctx, "查询任务失败: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
taskVO := make([]dto.TaskVO, 0, total)
|
||||||
|
err = gconv.Struct(t, taskVO)
|
||||||
|
if err != nil {
|
||||||
|
g.Log().Errorf(ctx, "转换任务失败: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
taskVO = append(taskVO, dto.TaskVO{
|
||||||
|
TaskType: req.TaskType,
|
||||||
|
Status: req.Status,
|
||||||
|
})
|
||||||
|
completed := IsAllSubTasksCompleted(taskVO)
|
||||||
|
|
||||||
|
// 1. 查询是否已存在该文档的该类型任务
|
||||||
|
existTask, _, err := dao.Task.Get(ctx, &dto.GetTaskReq{
|
||||||
|
TaskId: req.TaskId,
|
||||||
|
TaskType: req.TaskType,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
g.Log().Errorf(ctx, "查询任务失败: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 如果不存在,则创建新任务
|
||||||
|
if g.IsEmpty(existTask) {
|
||||||
|
createReq := &dto.CreateTaskReq{
|
||||||
|
TaskId: req.TaskId,
|
||||||
|
TaskType: req.TaskType,
|
||||||
|
Status: req.Status,
|
||||||
|
Remark: req.Remark,
|
||||||
|
}
|
||||||
|
_, err = dao.Task.Insert(ctx, createReq)
|
||||||
|
} else {
|
||||||
|
// 3. 如果已存在,则更新任务
|
||||||
|
updateReq := &dto.UpdateTaskReq{
|
||||||
|
Id: existTask[0].Id,
|
||||||
|
Status: req.Status,
|
||||||
|
Remark: req.Remark,
|
||||||
|
}
|
||||||
|
_, err = dao.Task.Update(ctx, updateReq)
|
||||||
|
if err != nil {
|
||||||
|
g.Log().Errorf(ctx, "更新任务失败: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if completed {
|
||||||
|
// 3. 如果已存在,则更新任务
|
||||||
|
_, err = dao.Task.Update(ctx, &dto.UpdateTaskReq{
|
||||||
|
TaskId: req.TaskId,
|
||||||
|
Status: task.TaskStatusCompleted,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAllSubTasksCompleted 判断三个子任务是否全部完成
|
||||||
|
// 参数:传入当前文档的所有子任务列表
|
||||||
|
func IsAllSubTasksCompleted(subTasks []dto.TaskVO) bool {
|
||||||
|
// 必须包含 3 种任务类型
|
||||||
|
hasKeywords := false
|
||||||
|
hasVector := false
|
||||||
|
hasFullText := false
|
||||||
|
|
||||||
|
for _, t := range subTasks {
|
||||||
|
// 子任务必须是【已完成】状态才计数
|
||||||
|
if t.Status == task.TaskStatusCompleted {
|
||||||
|
switch t.TaskType {
|
||||||
|
case task.TaskTypeExtractKeywords:
|
||||||
|
hasKeywords = true
|
||||||
|
case task.TaskTypeGenerateVector:
|
||||||
|
hasVector = true
|
||||||
|
case task.TaskTypeFullTextSearch:
|
||||||
|
hasFullText = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 三个任务全部完成 → 返回true
|
||||||
|
return hasKeywords && hasVector && hasFullText
|
||||||
|
}
|
||||||
44
update.sql
44
update.sql
@@ -114,6 +114,7 @@ COMMENT ON COLUMN rag_knowledge_document.file_path IS '文件存储路径(如M
|
|||||||
COMMENT ON COLUMN rag_knowledge_document.metadata IS '文件元数据,结构:{"author":"作者","tags":["标签1","标签2"],"custom":{"key":"值"}}';
|
COMMENT ON COLUMN rag_knowledge_document.metadata IS '文件元数据,结构:{"author":"作者","tags":["标签1","标签2"],"custom":{"key":"值"}}';
|
||||||
|
|
||||||
--------------------pgsql创建rag_knowledge_document表语句---------------------------
|
--------------------pgsql创建rag_knowledge_document表语句---------------------------
|
||||||
|
|
||||||
--------------------pgsql创建rag_knowledge_keyword表语句---------------------------
|
--------------------pgsql创建rag_knowledge_keyword表语句---------------------------
|
||||||
-- 关键词表(文档关键词+权重)
|
-- 关键词表(文档关键词+权重)
|
||||||
CREATE TABLE IF NOT EXISTS rag_knowledge_keyword (
|
CREATE TABLE IF NOT EXISTS rag_knowledge_keyword (
|
||||||
@@ -161,6 +162,49 @@ COMMENT ON COLUMN rag_knowledge_keyword.weight IS '权重';
|
|||||||
|
|
||||||
--------------------pgsql创建rag_knowledge_keyword表语句---------------------------
|
--------------------pgsql创建rag_knowledge_keyword表语句---------------------------
|
||||||
|
|
||||||
|
--------------------pgsql创建rag_knowledge_task表语句---------------------------
|
||||||
|
-- 知识库任务表
|
||||||
|
CREATE TABLE IF NOT EXISTS rag_knowledge_task (
|
||||||
|
-- 基础字段(完全对齐项目规范)
|
||||||
|
id BIGINT PRIMARY KEY, -- 主键ID(非自增)
|
||||||
|
tenant_id BIGINT NOT NULL DEFAULT 0, -- 租户ID int8
|
||||||
|
creator VARCHAR(64) NOT NULL,
|
||||||
|
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
updater VARCHAR(64) NOT NULL,
|
||||||
|
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
deleted_at timestamp(6),
|
||||||
|
|
||||||
|
-- 业务字段
|
||||||
|
task_id BIGINT NOT NULL, -- 任务ID
|
||||||
|
task_type VARCHAR(32) NOT NULL, -- 任务类型
|
||||||
|
status VARCHAR(32) NOT NULL, -- 任务状态
|
||||||
|
executor VARCHAR(128) DEFAULT '', -- 执行器
|
||||||
|
remark TEXT DEFAULT '' -- 备注
|
||||||
|
);
|
||||||
|
|
||||||
|
-- 索引(高频查询)
|
||||||
|
CREATE INDEX idx_rkt_tenant_id ON rag_knowledge_task(tenant_id);
|
||||||
|
CREATE INDEX idx_rkt_task_id ON rag_knowledge_task(task_id);
|
||||||
|
CREATE INDEX idx_rkt_task_type ON rag_knowledge_task(task_type);
|
||||||
|
CREATE INDEX idx_rkt_status ON rag_knowledge_task(status);
|
||||||
|
CREATE INDEX idx_rkt_deleted_at ON rag_knowledge_task(deleted_at);
|
||||||
|
|
||||||
|
-- 表和字段注释
|
||||||
|
COMMENT ON TABLE rag_knowledge_task IS '知识库任务表';
|
||||||
|
COMMENT ON COLUMN rag_knowledge_task.id IS '主键ID(非自增)';
|
||||||
|
COMMENT ON COLUMN rag_knowledge_task.tenant_id IS '租户ID';
|
||||||
|
COMMENT ON COLUMN rag_knowledge_task.creator IS '创建人';
|
||||||
|
COMMENT ON COLUMN rag_knowledge_task.created_at IS '创建时间';
|
||||||
|
COMMENT ON COLUMN rag_knowledge_task.updater IS '更新人';
|
||||||
|
COMMENT ON COLUMN rag_knowledge_task.updated_at IS '更新时间';
|
||||||
|
COMMENT ON COLUMN rag_knowledge_task.deleted_at IS '删除时间(软删)';
|
||||||
|
COMMENT ON COLUMN rag_knowledge_task.task_id IS '任务ID';
|
||||||
|
COMMENT ON COLUMN rag_knowledge_task.task_type IS '任务类型';
|
||||||
|
COMMENT ON COLUMN rag_knowledge_task.status IS '任务状态';
|
||||||
|
COMMENT ON COLUMN rag_knowledge_task.executor IS '执行器';
|
||||||
|
COMMENT ON COLUMN rag_knowledge_task.remark IS '备注';
|
||||||
|
|
||||||
|
--------------------pgsql创建rag_knowledge_task表语句---------------------------
|
||||||
|
|
||||||
|
|
||||||
--------------------pgsql创建rag_vector_dataset_index表语句---------------------------
|
--------------------pgsql创建rag_vector_dataset_index表语句---------------------------
|
||||||
|
|||||||
Reference in New Issue
Block a user