Compare commits
16 Commits
master
...
b2f7cff277
| Author | SHA1 | Date | |
|---|---|---|---|
| b2f7cff277 | |||
| 93aef365e7 | |||
| cfcf705503 | |||
| 2ced0a43e5 | |||
| 14a429f4ae | |||
| ff5fc54b35 | |||
| 7f894745e9 | |||
| e5a27c00ed | |||
| b6896f3fb4 | |||
| 026beea4d9 | |||
| 86c2b7d66e | |||
| 722fbe0cc3 | |||
| 6d68b468a6 | |||
| b33d50944a | |||
| 6f2df61bc5 | |||
| b00d544fb7 |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -1 +1 @@
|
||||
/.idea/*
|
||||
/.idea/*
|
||||
24
Dockerfile
Normal file
24
Dockerfile
Normal file
@@ -0,0 +1,24 @@
|
||||
# 最小化Docker镜像
|
||||
FROM busybox:uclibc
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# 复制时区数据
|
||||
COPY timezone/localtime /etc/localtime
|
||||
COPY timezone/timezone /etc/timezone
|
||||
COPY timezone/Shanghai /usr/share/zoneinfo/Asia/Shanghai
|
||||
|
||||
# 复制预构建的二进制文件和配置文件
|
||||
COPY rag_binary ./main
|
||||
COPY config.yml ./
|
||||
|
||||
# 创建日志目录
|
||||
RUN mkdir -p /logs /app/resource/log/run /app/resource/log/server
|
||||
|
||||
# 添加执行权限
|
||||
RUN chmod +x /app/main
|
||||
|
||||
EXPOSE 3006
|
||||
|
||||
# 使用root用户运行
|
||||
CMD ["./main"]
|
||||
156
common/eino/chat_model.go
Normal file
156
common/eino/chat_model.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package eino
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/qwen"
|
||||
"github.com/cloudwego/eino/components/prompt"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/os/glog"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
const (
|
||||
MaxHistoryTurns = 5 // 最大历史轮数
|
||||
)
|
||||
|
||||
var (
|
||||
globalChatModel *qwen.ChatModel
|
||||
ragPromptTemplate prompt.ChatTemplate // EINO 官方模板
|
||||
)
|
||||
|
||||
func init() {
|
||||
ctx := context.Background()
|
||||
// 初始化大模型
|
||||
if err := initChatModel(ctx); err != nil {
|
||||
glog.Errorf(ctx, "初始化大模型失败: %v", err)
|
||||
}
|
||||
// 初始化 EINO 提示词模板
|
||||
initRAGPromptTemplate()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// 初始化通义千问
|
||||
func initChatModel(ctx context.Context) error {
|
||||
if globalChatModel != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
apiKey := g.Cfg().MustGet(ctx, "eino.chatmodel.apiKey").String()
|
||||
model := g.Cfg().MustGet(ctx, "eino.chatmodel.model").String()
|
||||
|
||||
cm, err := qwen.NewChatModel(ctx, &qwen.ChatModelConfig{
|
||||
APIKey: apiKey,
|
||||
Model: model,
|
||||
BaseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
Timeout: 60 * 1e9,
|
||||
Temperature: gconv.PtrFloat32(0.7), // 客服最佳
|
||||
MaxTokens: gconv.PtrInt(1024), // 最长回答
|
||||
TopP: gconv.PtrFloat32(1.0),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
globalChatModel = cm
|
||||
return nil
|
||||
}
|
||||
|
||||
// 初始化 EINO 官方提示词模板(最关键!)
|
||||
func initRAGPromptTemplate() {
|
||||
ragPromptTemplate = prompt.FromMessages(
|
||||
schema.FString,
|
||||
// 系统提示(带参考知识)
|
||||
&schema.Message{
|
||||
Role: schema.System,
|
||||
Content: `你是专业客服,语气友好简洁。
|
||||
请严格依据参考知识回答,不知道就说:抱歉,我暂时无法回答这个问题。
|
||||
|
||||
参考知识:
|
||||
{knowledge}`,
|
||||
},
|
||||
// 用户问题
|
||||
&schema.Message{
|
||||
Role: schema.User,
|
||||
Content: "{question}",
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// NewChatModel 只处理逻辑,不复用创建模型
|
||||
func NewChatModel(ctx context.Context, question string, docs []*schema.Document, history []*schema.Message) (replyMsg *schema.Message, err error) {
|
||||
// 1. 构建参考知识
|
||||
knowledge := buildKnowledgeAndSources(docs)
|
||||
// 2. 历史精简
|
||||
history = limitHistory(history)
|
||||
// 3. ✅ EINO 官方模板格式化(超级干净)
|
||||
msgs, err := ragPromptTemplate.Format(ctx, map[string]any{
|
||||
"knowledge": knowledge,
|
||||
"question": question,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 4. 历史插入到模板消息中间(标准EINO用法)
|
||||
if len(history) > 0 {
|
||||
msgs = append(msgs[:1], append(history, msgs[1:]...)...)
|
||||
}
|
||||
// 5. 🔥 直接使用全局单例,不重复创建
|
||||
replyMsg, err = streamGenerateAnswer(ctx, globalChatModel, msgs)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func limitHistory(history []*schema.Message) []*schema.Message {
|
||||
valid := make([]*schema.Message, 0, len(history))
|
||||
for _, m := range history {
|
||||
if m.Role == schema.User || m.Role == schema.Assistant {
|
||||
valid = append(valid, m)
|
||||
}
|
||||
}
|
||||
|
||||
keep := 2 * MaxHistoryTurns
|
||||
if len(valid) > keep {
|
||||
valid = valid[len(valid)-keep:]
|
||||
}
|
||||
return valid
|
||||
}
|
||||
|
||||
// buildKnowledgeAndSources 拼接参考知识
|
||||
func buildKnowledgeAndSources(docs []*schema.Document) string {
|
||||
var knowledge string
|
||||
|
||||
for i, doc := range docs {
|
||||
knowledge += fmt.Sprintf("[参考%d] %s\n", i+1, doc.Content)
|
||||
|
||||
}
|
||||
return knowledge
|
||||
}
|
||||
|
||||
// streamGenerateAnswer 流式生成
|
||||
func streamGenerateAnswer(ctx context.Context, chatModel *qwen.ChatModel, msgs []*schema.Message) (reply *schema.Message, err error) {
|
||||
|
||||
sr, err := chatModel.Stream(ctx, msgs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stream failed: %w", err)
|
||||
}
|
||||
|
||||
var chunks []*schema.Message
|
||||
for {
|
||||
chunk, err := sr.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stream recv failed: %w", err)
|
||||
}
|
||||
chunks = append(chunks, chunk)
|
||||
}
|
||||
|
||||
return schema.ConcatMessages(chunks)
|
||||
}
|
||||
8
common/eino/consts.go
Normal file
8
common/eino/consts.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package eino
|
||||
|
||||
const (
|
||||
providerArk = "ark"
|
||||
providerOpenai = "openai"
|
||||
providerQianfan = "qianfan"
|
||||
providerDashscope = "dashscope"
|
||||
)
|
||||
51
common/eino/document_loader.go
Normal file
51
common/eino/document_loader.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package eino
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/cloudwego/eino-ext/components/document/loader/url"
|
||||
"github.com/cloudwego/eino-ext/components/document/parser/docx"
|
||||
"github.com/cloudwego/eino-ext/components/document/parser/pdf"
|
||||
"github.com/cloudwego/eino-ext/components/document/parser/xlsx"
|
||||
"github.com/cloudwego/eino/components/document"
|
||||
"github.com/cloudwego/eino/components/document/parser"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// LoadDocument 业务函数:加载文件
|
||||
func LoadDocument(ctx context.Context, filePath, fileFormat string) (docs []*schema.Document, err error) {
|
||||
p, err := docsParser(ctx, fileFormat)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
loader, err := url.NewLoader(ctx, &url.LoaderConfig{
|
||||
Parser: p,
|
||||
})
|
||||
imageUrl, err := utils.GetFileAddressPrefix(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
docs, err = loader.Load(context.Background(), document.Source{
|
||||
URI: fmt.Sprintf("%s%s", imageUrl, filePath),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func docsParser(ctx context.Context, fileFormat string) (p parser.Parser, err error) {
|
||||
switch fileFormat {
|
||||
case "docx":
|
||||
p, err = docx.NewDocxParser(ctx, &docx.Config{
|
||||
ToSections: true,
|
||||
IncludeHeaders: true,
|
||||
IncludeFooters: true,
|
||||
IncludeTables: true,
|
||||
})
|
||||
case "pdf":
|
||||
p, err = pdf.NewPDFParser(ctx, &pdf.Config{})
|
||||
case "xlsx":
|
||||
p, err = xlsx.NewXlsxParser(ctx, &xlsx.Config{})
|
||||
}
|
||||
return
|
||||
}
|
||||
64
common/eino/document_semantic.go
Normal file
64
common/eino/document_semantic.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package eino
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive"
|
||||
"github.com/cloudwego/eino-ext/components/document/transformer/splitter/semantic"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
// SemanticSplitDocument 语义分割文档
|
||||
func SemanticSplitDocument(ctx context.Context, docs []*schema.Document) (res []*schema.Document, err error) {
|
||||
// 默认分隔符(支持中英文)
|
||||
separators := []string{"\n\n", "\n", "。", "!", "?", ";", ".", "!", "?", ";"}
|
||||
// 读取配置,使用合理的默认值
|
||||
bufferSize := g.Cfg().MustGet(ctx, "eino.splitter.bufferSize").Int()
|
||||
minChunkSize := g.Cfg().MustGet(ctx, "eino.splitter.minChunkSize").Int()
|
||||
percentile := g.Cfg().MustGet(ctx, "eino.splitter.percentile").Float64()
|
||||
batchSize := g.Cfg().MustGet(ctx, "eino.splitter.batchSize").Int()
|
||||
if batchSize <= 0 {
|
||||
batchSize = 10 // doubao-embedding-vision 限制每批最多 10 个
|
||||
}
|
||||
|
||||
// 使用批量包装器
|
||||
var batchEmbedder *BatchEmbedder
|
||||
provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String()
|
||||
switch provider {
|
||||
case providerArk:
|
||||
batchEmbedder = NewBatchEmbedder(EmbedderArk, batchSize)
|
||||
case providerOpenai:
|
||||
batchEmbedder = NewBatchEmbedder(EmbedderOpenAI, batchSize)
|
||||
case providerDashscope:
|
||||
batchEmbedder = NewBatchEmbedder(EmbedderDashscope, batchSize)
|
||||
}
|
||||
|
||||
splitter, err := semantic.NewSplitter(ctx, &semantic.Config{
|
||||
Embedding: batchEmbedder,
|
||||
BufferSize: bufferSize,
|
||||
MinChunkSize: minChunkSize,
|
||||
Percentile: percentile,
|
||||
Separators: separators,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return splitter.Transform(ctx, docs)
|
||||
}
|
||||
|
||||
// RecursiveSplitDocument 递归分割文档
|
||||
func RecursiveSplitDocument(ctx context.Context, docs []*schema.Document) (res []*schema.Document, err error) {
|
||||
// 默认分隔符(支持中英文)
|
||||
separators := []string{"\n\n", "\n", "。", "!", "?", ";", ".", "!", "?", ";"}
|
||||
splitter, err := recursive.NewSplitter(ctx, &recursive.Config{
|
||||
ChunkSize: 512,
|
||||
OverlapSize: 100,
|
||||
KeepType: recursive.KeepTypeNone,
|
||||
Separators: separators,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return splitter.Transform(ctx, docs)
|
||||
}
|
||||
69
common/eino/embedding.go
Normal file
69
common/eino/embedding.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package eino
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/embedding/ark"
|
||||
"github.com/cloudwego/eino-ext/components/embedding/dashscope"
|
||||
"github.com/cloudwego/eino-ext/components/embedding/openai"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/golang/glog"
|
||||
)
|
||||
|
||||
// 全局只初始化一次
|
||||
var (
|
||||
EmbedderArk *ark.Embedder
|
||||
EmbedderDashscope *dashscope.Embedder
|
||||
EmbedderOpenAI *openai.Embedder
|
||||
)
|
||||
|
||||
func init() {
|
||||
ctx := context.Background()
|
||||
if !g.Cfg().MustGet(ctx, "eino.embedding").IsEmpty() {
|
||||
var err error
|
||||
provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String()
|
||||
switch provider {
|
||||
case providerArk:
|
||||
cfg := &ark.EmbeddingConfig{
|
||||
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
|
||||
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
|
||||
}
|
||||
if apiType := g.Cfg().MustGet(ctx, "eino.embedding.apiType").String(); apiType != "" {
|
||||
apiTypeVal := ark.APIType(apiType)
|
||||
cfg.APIType = &apiTypeVal
|
||||
}
|
||||
EmbedderArk, err = ark.NewEmbedder(ctx, cfg)
|
||||
case providerOpenai:
|
||||
chatModelConfig := &openai.EmbeddingConfig{
|
||||
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
|
||||
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
|
||||
}
|
||||
EmbedderOpenAI, err = openai.NewEmbedder(ctx, chatModelConfig)
|
||||
case providerDashscope:
|
||||
cfg := &dashscope.EmbeddingConfig{
|
||||
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
|
||||
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
|
||||
}
|
||||
EmbedderDashscope, err = dashscope.NewEmbedder(ctx, cfg)
|
||||
}
|
||||
if err != nil {
|
||||
glog.Fatalf("NewEmbedder of %v error: %v", provider, err)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func EmbedStrings(ctx context.Context, texts []string) (embeddings [][]float64, err error) {
|
||||
provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String()
|
||||
switch provider {
|
||||
case providerArk:
|
||||
return EmbedderArk.EmbedStrings(ctx, texts)
|
||||
case providerOpenai:
|
||||
return EmbedderOpenAI.EmbedStrings(ctx, texts)
|
||||
case providerDashscope:
|
||||
return EmbedderDashscope.EmbedStrings(ctx, texts)
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported provider: %v", provider)
|
||||
}
|
||||
47
common/eino/embedding_batch.go
Normal file
47
common/eino/embedding_batch.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package eino
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudwego/eino/components/embedding"
|
||||
)
|
||||
|
||||
// BatchEmbedder 包装器,支持批量限制
|
||||
type BatchEmbedder struct {
|
||||
embedder embedding.Embedder
|
||||
batchSize int
|
||||
}
|
||||
|
||||
// NewBatchEmbedder 创建支持批量限制的 embedding 包装器
|
||||
func NewBatchEmbedder(embedder embedding.Embedder, batchSize int) *BatchEmbedder {
|
||||
if batchSize <= 0 {
|
||||
batchSize = 10 // 默认每批 10 个
|
||||
}
|
||||
return &BatchEmbedder{
|
||||
embedder: embedder,
|
||||
batchSize: batchSize,
|
||||
}
|
||||
}
|
||||
|
||||
// EmbedStrings 分批调用 embedding
|
||||
func (b *BatchEmbedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) {
|
||||
if len(texts) <= b.batchSize {
|
||||
return b.embedder.EmbedStrings(ctx, texts, opts...)
|
||||
}
|
||||
|
||||
var allEmbeddings [][]float64
|
||||
for i := 0; i < len(texts); i += b.batchSize {
|
||||
end := i + b.batchSize
|
||||
if end > len(texts) {
|
||||
end = len(texts)
|
||||
}
|
||||
|
||||
batch := texts[i:end]
|
||||
embeddings, err := b.embedder.EmbedStrings(ctx, batch, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
allEmbeddings = append(allEmbeddings, embeddings...)
|
||||
}
|
||||
return allEmbeddings, nil
|
||||
}
|
||||
177
common/eino/indexer.go
Normal file
177
common/eino/indexer.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package eino
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"rag/dao"
|
||||
"rag/model/dto"
|
||||
"rag/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
"github.com/cloudwego/eino/components/indexer"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/gogf/gf/v2/os/glog"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
"github.com/pgvector/pgvector-go"
|
||||
)
|
||||
|
||||
type PGVectorIndexerOptions struct {
|
||||
BatchSize int // 每批处理多少条
|
||||
}
|
||||
|
||||
type PGVectorIndexer struct {
|
||||
opts *PGVectorIndexerOptions
|
||||
}
|
||||
|
||||
func NewPGVectorIndexer(opts *PGVectorIndexerOptions) *PGVectorIndexer {
|
||||
// 默认值
|
||||
if opts.BatchSize <= 0 {
|
||||
opts.BatchSize = 5
|
||||
}
|
||||
return &PGVectorIndexer{opts: opts}
|
||||
}
|
||||
|
||||
func (i *PGVectorIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (rows int64, err error) {
|
||||
commonOpts := indexer.GetCommonOptions(&indexer.Options{}, opts...)
|
||||
|
||||
if commonOpts.Embedding == nil {
|
||||
return 0, errors.New("embedding model not set")
|
||||
}
|
||||
|
||||
// 回调
|
||||
ctx = callbacks.OnStart(ctx, &indexer.CallbackInput{Docs: docs})
|
||||
|
||||
rows, err = i.bulkStore(ctx, docs, commonOpts)
|
||||
if err != nil {
|
||||
callbacks.OnError(ctx, err)
|
||||
return
|
||||
}
|
||||
|
||||
callbacks.OnEnd(ctx, &indexer.CallbackOutput{IDs: gconv.Strings(rows)})
|
||||
return
|
||||
}
|
||||
|
||||
func (i *PGVectorIndexer) bulkStore(ctx context.Context, docs []*schema.Document, opts *indexer.Options) (rows int64, err error) {
|
||||
var batchDocs []*schema.Document
|
||||
|
||||
// 官方ES同款逻辑:满 BatchSize 就处理一批
|
||||
for _, doc := range docs {
|
||||
batchDocs = append(batchDocs, doc)
|
||||
|
||||
// 满了 → 处理
|
||||
if len(batchDocs) >= i.opts.BatchSize {
|
||||
var r int64
|
||||
r, err = i.doStore(ctx, batchDocs, opts)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
rows = rows + r
|
||||
batchDocs = nil
|
||||
}
|
||||
}
|
||||
|
||||
// 最后一批
|
||||
if len(batchDocs) > 0 {
|
||||
var r int64
|
||||
r, err = i.doStore(ctx, batchDocs, opts)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
rows = rows + r
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (i *PGVectorIndexer) doStore(ctx context.Context, docs []*schema.Document, opts *indexer.Options) (rows int64, err error) {
|
||||
|
||||
texts := make([]string, len(docs))
|
||||
for i, d := range docs {
|
||||
texts[i] = d.Content
|
||||
}
|
||||
|
||||
// 向量化(官方ES也没有重试!)
|
||||
vectors, err := opts.Embedding.EmbedStrings(ctx, texts)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 转成业务实体
|
||||
var chunks []*dto.VectorDocumentChunkMsg
|
||||
for idx, doc := range docs {
|
||||
ck := new(dto.VectorDocumentChunkMsg)
|
||||
err = gconv.Struct(doc.MetaData, ck)
|
||||
if err != nil {
|
||||
glog.Errorf(ctx, "doStore err: %v", err)
|
||||
continue
|
||||
}
|
||||
ck.Content = doc.Content
|
||||
ck.Vector = pgvector.NewVector(gconv.Float32s(vectors[idx]))
|
||||
ck.VectorStatus = gconv.PtrInt8(1)
|
||||
ck.Status = gconv.PtrInt8(1)
|
||||
chunks = append(chunks, ck)
|
||||
}
|
||||
if len(chunks) == 0 {
|
||||
return
|
||||
}
|
||||
ctx = context.WithValue(ctx, "user", &beans.User{
|
||||
TenantId: chunks[0].TenantId,
|
||||
UserName: chunks[0].Creator,
|
||||
})
|
||||
// 创建索引
|
||||
if err = i.createOrUpdateDatasetIndex(ctx, chunks[0].DatasetId, len(vectors[0]), int64(len(chunks))); err != nil {
|
||||
return
|
||||
}
|
||||
// 入库
|
||||
rows, err = dao.DocumentChunk.BatchInsert(ctx, chunks)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (i *PGVectorIndexer) createOrUpdateDatasetIndex(ctx context.Context, datasetId int64, dimension int, vectorCount int64) error {
|
||||
exist, err := dao.DatasetIndex.GetByDatasetId(ctx, datasetId)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return err
|
||||
}
|
||||
if exist != nil {
|
||||
_ = dao.DatasetIndex.IncVectorCount(ctx, exist.Id, vectorCount)
|
||||
return nil
|
||||
}
|
||||
|
||||
indexName := fmt.Sprintf("idx_dataset_%d_vector", datasetId)
|
||||
idx := &entity.DatasetIndex{
|
||||
DatasetId: datasetId,
|
||||
Name: indexName,
|
||||
Dimension: dimension,
|
||||
FieldType: "float",
|
||||
MetricType: "COSINE",
|
||||
Status: gconv.PtrInt8(1),
|
||||
VectorCount: vectorCount,
|
||||
Description: fmt.Sprintf("数据集%d向量索引", datasetId),
|
||||
}
|
||||
_, err = dao.DatasetIndex.Insert(ctx, idx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return i.createRealPGVectorIndex(ctx, indexName)
|
||||
}
|
||||
|
||||
func (i *PGVectorIndexer) createRealPGVectorIndex(ctx context.Context, indexName string) error {
|
||||
if err := dao.DatasetIndex.InsertIndex(ctx, indexName); err != nil {
|
||||
glog.Errorf(ctx, "create vector index failed: %v", err)
|
||||
return err
|
||||
}
|
||||
glog.Infof(ctx, "created pgvector index: %s", indexName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *PGVectorIndexer) GetType() string {
|
||||
return "pgvector_indexer"
|
||||
}
|
||||
|
||||
func (i *PGVectorIndexer) IsCallbacksEnabled() bool {
|
||||
return true
|
||||
}
|
||||
292
common/eino/retriever.go
Normal file
292
common/eino/retriever.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package eino
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"rag/dao"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
"github.com/cloudwego/eino/components/embedding"
|
||||
"github.com/cloudwego/eino/components/retriever"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/os/grpool"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
"github.com/pgvector/pgvector-go"
|
||||
)
|
||||
|
||||
type PGVectorRetrieverConfig struct {
|
||||
Embedder embedding.Embedder
|
||||
DefaultTopK int
|
||||
DefaultIndex string
|
||||
DSLInfo map[string]any
|
||||
}
|
||||
|
||||
type PGVectorRetriever struct {
|
||||
embedder embedding.Embedder
|
||||
topK int
|
||||
index string
|
||||
dslInfo map[string]any
|
||||
}
|
||||
|
||||
func NewPGVectorRetriever(config *PGVectorRetrieverConfig) (*PGVectorRetriever, error) {
|
||||
if config.Embedder == nil {
|
||||
return nil, errors.New("embedder is required")
|
||||
}
|
||||
if config.DefaultTopK <= 0 {
|
||||
config.DefaultTopK = 5
|
||||
}
|
||||
|
||||
return &PGVectorRetriever{
|
||||
embedder: config.Embedder,
|
||||
topK: config.DefaultTopK,
|
||||
index: config.DefaultIndex,
|
||||
dslInfo: config.DSLInfo,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *PGVectorRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
|
||||
options := &retriever.Options{
|
||||
Index: &r.index,
|
||||
TopK: &r.topK,
|
||||
DSLInfo: r.dslInfo,
|
||||
Embedding: r.embedder,
|
||||
}
|
||||
options = retriever.GetCommonOptions(options, opts...)
|
||||
|
||||
// 安全保护:防止 nil 指针 panic
|
||||
topK := 10
|
||||
if options.TopK != nil {
|
||||
topK = *options.TopK
|
||||
}
|
||||
|
||||
ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{
|
||||
Query: query,
|
||||
TopK: *options.TopK,
|
||||
})
|
||||
|
||||
// ==========================================
|
||||
// 🔥 优化版:grpool 并行双路检索(安全、健壮、无泄漏)
|
||||
// ==========================================
|
||||
var (
|
||||
docsVector []*schema.Document
|
||||
docsFulltext []*schema.Document
|
||||
errVector error
|
||||
errFulltext error
|
||||
|
||||
// 缓冲通道=2,确保无死锁等待
|
||||
done = make(chan struct{}, 2)
|
||||
)
|
||||
|
||||
// 上下文:超时 + 可取消双保障(建议5s超时,根据业务调整)
|
||||
taskCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 封装并行任务函数,消除重复代码
|
||||
runTask := func(task func() error, errTarget *error) {
|
||||
defer func() {
|
||||
// 任务结束必发信号,确保通道不阻塞
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
// 捕获 panic + 执行业务逻辑
|
||||
g.TryCatch(taskCtx, func(ctx context.Context) {
|
||||
*errTarget = task()
|
||||
}, func(ctx context.Context, panicErr error) {
|
||||
*errTarget = panicErr
|
||||
})
|
||||
|
||||
// 任务失败:立即取消另一个任务(快速失败)
|
||||
if *errTarget != nil {
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------
|
||||
// 并行提交两个检索任务
|
||||
// ----------------------
|
||||
// 任务1:向量检索
|
||||
grpool.Add(taskCtx, func(ctx context.Context) {
|
||||
runTask(func() error {
|
||||
docsVector, errVector = r.doRetrieveVector(ctx, query, options)
|
||||
return errVector
|
||||
}, &errVector)
|
||||
})
|
||||
|
||||
// 任务2:全文检索
|
||||
grpool.Add(taskCtx, func(ctx context.Context) {
|
||||
runTask(func() error {
|
||||
docsFulltext, errFulltext = r.doRetrieveMeilisearch(ctx, query, options)
|
||||
return errFulltext
|
||||
}, &errFulltext)
|
||||
})
|
||||
|
||||
// ----------------------
|
||||
// 安全等待所有任务完成
|
||||
// ----------------------
|
||||
<-done
|
||||
<-done
|
||||
|
||||
// ----------------------
|
||||
// 统一错误处理
|
||||
// ----------------------
|
||||
// 用 errors.Join 合并所有错误,不丢失信息
|
||||
if err := errors.Join(errVector, errFulltext); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 合并 + 智能去重(保留最优分数)
|
||||
docs := mergeAndDeduplicate(docsVector, docsFulltext)
|
||||
|
||||
// 排序:向量优先,同类型按距离升序
|
||||
sort.Slice(docs, func(i, j int) bool {
|
||||
//byI, okI := docs[i].MetaData["retrieve_by"].(string)
|
||||
//byJ, okJ := docs[j].MetaData["retrieve_by"].(string)
|
||||
//
|
||||
//// 有类型标记的优先
|
||||
//if okI && !okJ {
|
||||
// return true
|
||||
//}
|
||||
//if !okI && okJ {
|
||||
// return false
|
||||
//}
|
||||
//
|
||||
//// 向量永远排前面
|
||||
//if byI == "vector" && byJ == "fulltext" {
|
||||
// return true
|
||||
//}
|
||||
//if byI == "fulltext" && byJ == "vector" {
|
||||
// return false
|
||||
//}
|
||||
|
||||
// 同类型按 distance 升序(越小越相似)
|
||||
d1 := gconv.Float64(docs[i].MetaData["distance"])
|
||||
d2 := gconv.Float64(docs[j].MetaData["distance"])
|
||||
return d1 < d2
|
||||
})
|
||||
|
||||
// 在Retrieve方法末尾,增加相关性校验
|
||||
validDocs := make([]*schema.Document, 0)
|
||||
for i, d := range docs {
|
||||
// 过滤distance过大的垃圾结果(比如distance>0.8的直接丢弃)
|
||||
if gconv.Float64(docs[i].MetaData["distance"]) < 0.8 {
|
||||
validDocs = append(validDocs, d)
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有有效结果,返回空,让LLM回答「暂无相关信息」
|
||||
if len(validDocs) == 0 {
|
||||
callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: validDocs})
|
||||
return validDocs, nil
|
||||
}
|
||||
|
||||
// 最多保留 topK
|
||||
if len(validDocs) > topK {
|
||||
validDocs = validDocs[:topK]
|
||||
}
|
||||
|
||||
callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: validDocs})
|
||||
return validDocs, nil
|
||||
}
|
||||
|
||||
// ==========================================
|
||||
// 1. 向量检索(PG)
|
||||
// ==========================================
|
||||
func (r *PGVectorRetriever) doRetrieveVector(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
|
||||
vectors, err := opts.Embedding.EmbedStrings(ctx, []string{query})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(vectors) == 0 {
|
||||
return nil, errors.New("empty query vector")
|
||||
}
|
||||
|
||||
queryVec := pgvector.NewVector(gconv.Float32s(vectors[0]))
|
||||
topK := 10
|
||||
if opts.TopK != nil {
|
||||
topK = *opts.TopK
|
||||
}
|
||||
datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"])
|
||||
|
||||
rows, err := dao.DocumentChunk.GetAllByVector(ctx, datasetIds, queryVec, topK)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
docs := make([]*schema.Document, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
docs = append(docs, &schema.Document{
|
||||
ID: gconv.String(row["id"]),
|
||||
Content: gconv.String(row["content"]),
|
||||
MetaData: map[string]any{
|
||||
"dataset_id": gconv.Int64(row["dataset_id"]),
|
||||
"document_id": gconv.Int64(row["document_id"]),
|
||||
"distance": gconv.Float64(row["distance"]),
|
||||
"retrieve_by": "vector",
|
||||
},
|
||||
})
|
||||
}
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
// ==========================================
|
||||
// 2. 全文检索(Meilisearch)🔥 新增
|
||||
// ==========================================
|
||||
func (r *PGVectorRetriever) doRetrieveMeilisearch(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
|
||||
topK := *opts.TopK
|
||||
datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"])
|
||||
|
||||
// 调用你已有的 Meilisearch DAO
|
||||
rows, err := dao.DocumentChunk.SearchByKeywords(ctx, query, datasetIds, topK)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
docs := make([]*schema.Document, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
score := gconv.Float64(row["_rankingScore"])
|
||||
distance := score
|
||||
|
||||
docs = append(docs, &schema.Document{
|
||||
ID: gconv.String(row["id"]),
|
||||
Content: gconv.String(row["content"]),
|
||||
MetaData: map[string]any{
|
||||
"dataset_id": gconv.Int64(row["dataset_id"]),
|
||||
"document_id": gconv.Int64(row["document_id"]),
|
||||
"distance": distance,
|
||||
"retrieve_by": "fulltext",
|
||||
},
|
||||
})
|
||||
}
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
// ==========================================
|
||||
// 合并去重(智能版:两路都命中时,保留向量结果 + 全文标记)
|
||||
// ==========================================
|
||||
func mergeAndDeduplicate(vecDocs, fullDocs []*schema.Document) []*schema.Document {
|
||||
idMap := make(map[string]*schema.Document)
|
||||
|
||||
// 先存入向量结果
|
||||
for _, d := range vecDocs {
|
||||
idMap[d.ID] = d
|
||||
}
|
||||
|
||||
// 再处理全文:不存在则添加;存在则标记“双路命中”,不覆盖向量分数
|
||||
for _, d := range fullDocs {
|
||||
if existDoc, ok := idMap[d.ID]; ok {
|
||||
// 标记同时被向量和全文检索到
|
||||
existDoc.MetaData["retrieve_by"] = "both"
|
||||
} else {
|
||||
idMap[d.ID] = d
|
||||
}
|
||||
}
|
||||
|
||||
merged := make([]*schema.Document, 0, len(idMap))
|
||||
for _, d := range idMap {
|
||||
merged = append(merged, d)
|
||||
}
|
||||
return merged
|
||||
}
|
||||
69
common/task/base_task.go
Normal file
69
common/task/base_task.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package task
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
)
|
||||
|
||||
type baseTaskCol struct {
|
||||
beans.SQLBaseCol
|
||||
TaskType string
|
||||
Status string
|
||||
Priority string
|
||||
ParentTaskID string
|
||||
TotalItems string
|
||||
ProcessedItems string
|
||||
Progress string
|
||||
StartTime string
|
||||
EndTime string
|
||||
Duration string
|
||||
SuccessCount string
|
||||
FailCount string
|
||||
Executor string
|
||||
DocumentID string
|
||||
Remark string
|
||||
}
|
||||
|
||||
var BaseTaskCol = baseTaskCol{
|
||||
SQLBaseCol: beans.DefSQLBaseCol,
|
||||
TaskType: "task_type",
|
||||
Status: "status",
|
||||
Priority: "task_priority",
|
||||
ParentTaskID: "parent_task_id",
|
||||
TotalItems: "total_items",
|
||||
ProcessedItems: "processed_items",
|
||||
Progress: "progress",
|
||||
StartTime: "start_time",
|
||||
EndTime: "end_time",
|
||||
Duration: "duration",
|
||||
SuccessCount: "success_count",
|
||||
FailCount: "fail_count",
|
||||
Executor: "executor",
|
||||
DocumentID: "document_id",
|
||||
Remark: "remark",
|
||||
}
|
||||
|
||||
// SQLBaseTask 任务基类 - SQL版本
|
||||
type SQLBaseTask struct {
|
||||
beans.SQLBaseDO `orm:",inline"`
|
||||
// 任务核心信息
|
||||
TaskType TaskType `orm:"task_type" json:"taskType" dc:"任务类型"`
|
||||
Status TaskStatus `orm:"status" json:"status" dc:"任务状态"`
|
||||
Priority TaskPriority `orm:"task_priority" json:"priority,omitempty" dc:"任务优先级"`
|
||||
ParentTaskID int64 `orm:"parent_task_id" json:"parentTaskId,omitempty" dc:"父任务ID"`
|
||||
// 任务进度
|
||||
TotalItems int64 `orm:"total_items" json:"totalItems" dc:"总数"`
|
||||
ProcessedItems int64 `orm:"processed_items" json:"processedItems" dc:"已处理数"`
|
||||
Progress float64 `orm:"progress" json:"progress" dc:"进度"` // 0~100 百分比
|
||||
// 任务结果
|
||||
StartTime *time.Time `orm:"start_time" json:"startTime" dc:"开始时间"`
|
||||
EndTime *time.Time `orm:"end_time" json:"endTime,omitempty" dc:"结束时间"`
|
||||
Duration int64 `orm:"duration" json:"duration,omitempty" dc:"耗时(毫秒)"`
|
||||
SuccessCount int64 `orm:"success_count" json:"successCount" dc:"成功数"`
|
||||
FailCount int64 `orm:"fail_count" json:"failCount" dc:"失败数"`
|
||||
// 其他
|
||||
Executor string `orm:"executor" json:"executor,omitempty" dc:"执行器标识"`
|
||||
DocumentID int64 `orm:"document_id" json:"documentId,omitempty" dc:"文档ID"`
|
||||
Remark string `orm:"remark" json:"remark,omitempty" dc:"备注/错误信息"`
|
||||
}
|
||||
30
common/task/consts.go
Normal file
30
common/task/consts.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package task
|
||||
|
||||
// TaskType 任务类型枚举:文档解析的三个子任务
|
||||
type TaskType string
|
||||
|
||||
const (
|
||||
TaskTypeExtractKeywords TaskType = "EXTRACT_KEYWORDS" // 提取关键词
|
||||
TaskTypeGenerateVector TaskType = "GENERATE_VECTOR" // 生成向量
|
||||
TaskTypeFullTextSearch TaskType = "FULL_TEXT_SEARCH" // 全文检索
|
||||
TaskTypeDocParse TaskType = "DOC_PARSE" // 顶层文档解析总任务
|
||||
)
|
||||
|
||||
// TaskStatus 任务状态枚举
|
||||
type TaskStatus string
|
||||
|
||||
const (
|
||||
TaskStatusPending TaskStatus = "PENDING" // 待执行
|
||||
TaskStatusRunning TaskStatus = "RUNNING" // 执行中
|
||||
TaskStatusCompleted TaskStatus = "COMPLETED" // 已完成
|
||||
TaskStatusFailed TaskStatus = "FAILED" // 执行失败
|
||||
)
|
||||
|
||||
// TaskPriority 任务优先级
|
||||
type TaskPriority int
|
||||
|
||||
const (
|
||||
TaskPriorityLow TaskPriority = 1 // 低
|
||||
TaskPriorityMedium TaskPriority = 2 // 中
|
||||
TaskPriorityHigh TaskPriority = 3 // 高
|
||||
)
|
||||
149
config-dev.yml
Normal file
149
config-dev.yml
Normal file
@@ -0,0 +1,149 @@
|
||||
server:
|
||||
address: :3006
|
||||
name: rag
|
||||
workerId: 1
|
||||
|
||||
# Database.
|
||||
database:
|
||||
default:
|
||||
- type: "pgsql"
|
||||
host: "116.204.74.41"
|
||||
port: "15432"
|
||||
user: "postgres"
|
||||
pass: "Bjang09@686^*^"
|
||||
name: "rag"
|
||||
prefix: "rag_knowledge_" # (可选)表名前缀
|
||||
role: "master" # (可选)数据库主从角色(master/slave),默认为master。如果不使用应用主从机制请不配置或留空即可。
|
||||
debug: true # (可选)开启调试模式
|
||||
dryRun: false # (可选)ORM空跑(只读不写)
|
||||
charset: "utf8" # (可选)数据库编码(如: utf8mb4/utf8/gbk/gb2312),一般设置为utf8mb4。默认为utf8。
|
||||
timezone: "Asia/Shanghai" # (可选)时区配置,例如:Local
|
||||
maxIdle: 5 # (可选)连接池最大闲置的连接数(默认10)
|
||||
maxOpen: 20 # (可选)连接池最大打开的连接数(默认无限制)
|
||||
maxLifetime: "30s" # (可选)连接对象可重复使用的时间长度(默认30秒)
|
||||
maxIdleConnTime: "30s" # (可选,v2.10新增)连接池中空闲连接的最大生存时间(默认30秒)。可以通过配置文件或SetConnMaxIdleTime方法设置,避免长时间空闲连接占用资源。
|
||||
createdAt: "created_at" # (可选)自动创建时间字段名称
|
||||
updatedAt: "updated_at" # (可选)自动更新时间字段名称
|
||||
deletedAt: "deleted_at" # (可选)软删除时间字段名称
|
||||
timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性,为true时CreatedAt/UpdatedAt/DeletedAt都将失效
|
||||
- type: "pgsql"
|
||||
host: "116.204.74.41"
|
||||
port: "15432"
|
||||
user: "postgres"
|
||||
pass: "Bjang09@686^*^"
|
||||
name: "tenant-1"
|
||||
prefix: "rag_knowledge_" # (可选)表名前缀
|
||||
role: "slave" # (可选)数据库主从角色(master/slave),默认为master。如果不使用应用主从机制请不配置或留空即可。
|
||||
debug: false # (可选)开启调试模式
|
||||
dryRun: false # (可选)ORM空跑(只读不写)
|
||||
charset: "utf8" # (可选)数据库编码(如: utf8mb4/utf8/gbk/gb2312),一般设置为utf8mb4。默认为utf8。
|
||||
timezone: "Asia/Shanghai" # (可选)时区配置,例如:Local
|
||||
maxIdle: 5 # (可选)连接池最大闲置的连接数(默认10)
|
||||
maxOpen: 20 # (可选)连接池最大打开的连接数(默认无限制)
|
||||
maxLifetime: "30s" # (可选)连接对象可重复使用的时间长度(默认30秒)
|
||||
maxIdleConnTime: "30s" # (可选,v2.10新增)连接池中空闲连接的最大生存时间(默认30秒)。可以通过配置文件或SetConnMaxIdleTime方法设置,避免长时间空闲连接占用资源。
|
||||
createdAt: "created_at" # (可选)自动创建时间字段名称
|
||||
updatedAt: "updated_at" # (可选)自动更新时间字段名称
|
||||
deletedAt: "deleted_at" # (可选)软删除时间字段名称
|
||||
timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性,为true时CreatedAt/UpdatedAt/DeletedAt都将失效
|
||||
rag_knowledge:
|
||||
- type: "pgsql"
|
||||
host: "116.204.74.41"
|
||||
port: "15432"
|
||||
user: "postgres"
|
||||
pass: "Bjang09@686^*^"
|
||||
name: "tenant-1"
|
||||
prefix: "rag_knowledge_" # (可选)表名前缀
|
||||
role: "master"
|
||||
debug: true # (可选)开启调试模式
|
||||
dryRun: false # (可选)ORM空跑(只读不写)
|
||||
charset: "utf8" # (可选)数据库编码(如: utf8mb4/utf8/gbk/gb2312),一般设置为utf8mb4。默认为utf8。
|
||||
timezone: "Asia/Shanghai" # (可选)时区配置,例如:Local
|
||||
maxIdle: 5 # (可选)连接池最大闲置的连接数(默认10)
|
||||
maxOpen: 20 # (可选)连接池最大打开的连接数(默认无限制)
|
||||
maxLifetime: "30s" # (可选)连接对象可重复使用的时间长度(默认30秒)
|
||||
maxIdleConnTime: "30s" # (可选,v2.10新增)连接池中空闲连接的最大生存时间(默认30秒)。可以通过配置文件或SetConnMaxIdleTime方法设置,避免长时间空闲连接占用资源。
|
||||
createdAt: "created_at" # (可选)自动创建时间字段名称
|
||||
updatedAt: "updated_at" # (可选)自动更新时间字段名称
|
||||
deletedAt: "deleted_at" # (可选)软删除时间字段名称
|
||||
timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性,为true时CreatedAt/UpdatedAt/DeletedAt都将失效
|
||||
rag_vector:
|
||||
- type: "pgsql"
|
||||
host: "116.204.74.41"
|
||||
port: "15432"
|
||||
user: "postgres"
|
||||
pass: "Bjang09@686^*^"
|
||||
name: "tenant-1"
|
||||
prefix: "rag_vector_" # (可选)表名前缀
|
||||
role: "master"
|
||||
debug: true # (可选)开启调试模式
|
||||
dryRun: false # (可选)ORM空跑(只读不写)
|
||||
charset: "utf8" # (可选)数据库编码(如: utf8mb4/utf8/gbk/gb2312),一般设置为utf8mb4。默认为utf8。
|
||||
timezone: "Asia/Shanghai" # (可选)时区配置,例如:Local
|
||||
maxIdle: 5 # (可选)连接池最大闲置的连接数(默认10)
|
||||
maxOpen: 20 # (可选)连接池最大打开的连接数(默认无限制)
|
||||
maxLifetime: "30s" # (可选)连接对象可重复使用的时间长度(默认30秒)
|
||||
maxIdleConnTime: "30s" # (可选,v2.10新增)连接池中空闲连接的最大生存时间(默认30秒)。可以通过配置文件或SetConnMaxIdleTime方法设置,避免长时间空闲连接占用资源。
|
||||
createdAt: "created_at" # (可选)自动创建时间字段名称
|
||||
updatedAt: "updated_at" # (可选)自动更新时间字段名称
|
||||
deletedAt: "deleted_at" # (可选)软删除时间字段名称
|
||||
timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性,为true时CreatedAt/UpdatedAt/DeletedAt都将失效
|
||||
|
||||
redis:
|
||||
default:
|
||||
address: "116.204.74.41:6379"
|
||||
db: 0
|
||||
|
||||
consul:
|
||||
address: 116.204.74.41:8500
|
||||
|
||||
jaeger:
|
||||
addr: 116.204.74.41:4318
|
||||
|
||||
# eino框架配置
|
||||
eino:
|
||||
# 文件切分配置
|
||||
splitter:
|
||||
bufferSize: 1
|
||||
minChunkSize: 64
|
||||
percentile: 0.75
|
||||
# 向量化配置
|
||||
embedding:
|
||||
provider: "dashscope"
|
||||
# apiKey: "d158d896-8c54-40ee-9d61-4c5d37cd545c"
|
||||
# model: "ep-20260326123502-khmdq"
|
||||
# apiType: "multi_modal_api"
|
||||
apiKey: "sk-4a8b82770bf74bc490eb3e4c5a8e2be9"
|
||||
model: "text-embedding-v3"
|
||||
chatmodel:
|
||||
provider: "dashscope"
|
||||
apiKey: "sk-4a8b82770bf74bc490eb3e4c5a8e2be9"
|
||||
model: "qwen-turbo"
|
||||
|
||||
# 文件上传服务地址,与oss模块minio中的endpoint一致
|
||||
filePrefix: "http://116.204.74.41:9000"
|
||||
|
||||
gmq:
|
||||
redis:
|
||||
primary:
|
||||
addr: "116.204.74.41"
|
||||
port: "6379"
|
||||
db: 0
|
||||
username: ""
|
||||
password: ""
|
||||
poolSize: 10
|
||||
minIdleConn: 5
|
||||
maxActiveConn: 10
|
||||
maxRetries: 30
|
||||
|
||||
# Meilisearch 全文检索配置
|
||||
meilisearch:
|
||||
default:
|
||||
host: "http://localhost"
|
||||
port: 7700
|
||||
apiKey: "admin"
|
||||
# apiKey: "6b8b6062bcb5e31f150427961d9da1a9e81758aa"
|
||||
|
||||
cache:
|
||||
localTTL: 60
|
||||
redisTTL: 300
|
||||
149
config-master.yml
Normal file
149
config-master.yml
Normal file
@@ -0,0 +1,149 @@
|
||||
server:
|
||||
address: :3006
|
||||
name: rag
|
||||
workerId: 1
|
||||
|
||||
# Database.
|
||||
database:
|
||||
default:
|
||||
- type: "pgsql"
|
||||
host: "192.168.0.169"
|
||||
port: "15432"
|
||||
user: "postgres"
|
||||
pass: "Bjang09@686^*^"
|
||||
name: "rag"
|
||||
prefix: "rag_knowledge_" # (可选)表名前缀
|
||||
role: "master" # (可选)数据库主从角色(master/slave),默认为master。如果不使用应用主从机制请不配置或留空即可。
|
||||
debug: true # (可选)开启调试模式
|
||||
dryRun: false # (可选)ORM空跑(只读不写)
|
||||
charset: "utf8" # (可选)数据库编码(如: utf8mb4/utf8/gbk/gb2312),一般设置为utf8mb4。默认为utf8。
|
||||
timezone: "Asia/Shanghai" # (可选)时区配置,例如:Local
|
||||
maxIdle: 5 # (可选)连接池最大闲置的连接数(默认10)
|
||||
maxOpen: 20 # (可选)连接池最大打开的连接数(默认无限制)
|
||||
maxLifetime: "30s" # (可选)连接对象可重复使用的时间长度(默认30秒)
|
||||
maxIdleConnTime: "30s" # (可选,v2.10新增)连接池中空闲连接的最大生存时间(默认30秒)。可以通过配置文件或SetConnMaxIdleTime方法设置,避免长时间空闲连接占用资源。
|
||||
createdAt: "created_at" # (可选)自动创建时间字段名称
|
||||
updatedAt: "updated_at" # (可选)自动更新时间字段名称
|
||||
deletedAt: "deleted_at" # (可选)软删除时间字段名称
|
||||
timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性,为true时CreatedAt/UpdatedAt/DeletedAt都将失效
|
||||
- type: "pgsql"
|
||||
host: "192.168.0.169"
|
||||
port: "15432"
|
||||
user: "postgres"
|
||||
pass: "Bjang09@686^*^"
|
||||
name: "tenant-1"
|
||||
prefix: "rag_knowledge_" # (可选)表名前缀
|
||||
role: "slave" # (可选)数据库主从角色(master/slave),默认为master。如果不使用应用主从机制请不配置或留空即可。
|
||||
debug: false # (可选)开启调试模式
|
||||
dryRun: false # (可选)ORM空跑(只读不写)
|
||||
charset: "utf8" # (可选)数据库编码(如: utf8mb4/utf8/gbk/gb2312),一般设置为utf8mb4。默认为utf8。
|
||||
timezone: "Asia/Shanghai" # (可选)时区配置,例如:Local
|
||||
maxIdle: 5 # (可选)连接池最大闲置的连接数(默认10)
|
||||
maxOpen: 20 # (可选)连接池最大打开的连接数(默认无限制)
|
||||
maxLifetime: "30s" # (可选)连接对象可重复使用的时间长度(默认30秒)
|
||||
maxIdleConnTime: "30s" # (可选,v2.10新增)连接池中空闲连接的最大生存时间(默认30秒)。可以通过配置文件或SetConnMaxIdleTime方法设置,避免长时间空闲连接占用资源。
|
||||
createdAt: "created_at" # (可选)自动创建时间字段名称
|
||||
updatedAt: "updated_at" # (可选)自动更新时间字段名称
|
||||
deletedAt: "deleted_at" # (可选)软删除时间字段名称
|
||||
timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性,为true时CreatedAt/UpdatedAt/DeletedAt都将失效
|
||||
rag_knowledge:
|
||||
- type: "pgsql"
|
||||
host: "192.168.0.169"
|
||||
port: "15432"
|
||||
user: "postgres"
|
||||
pass: "Bjang09@686^*^"
|
||||
name: "tenant-1"
|
||||
prefix: "rag_knowledge_" # (可选)表名前缀
|
||||
role: "master"
|
||||
debug: true # (可选)开启调试模式
|
||||
dryRun: false # (可选)ORM空跑(只读不写)
|
||||
charset: "utf8" # (可选)数据库编码(如: utf8mb4/utf8/gbk/gb2312),一般设置为utf8mb4。默认为utf8。
|
||||
timezone: "Asia/Shanghai" # (可选)时区配置,例如:Local
|
||||
maxIdle: 5 # (可选)连接池最大闲置的连接数(默认10)
|
||||
maxOpen: 20 # (可选)连接池最大打开的连接数(默认无限制)
|
||||
maxLifetime: "30s" # (可选)连接对象可重复使用的时间长度(默认30秒)
|
||||
maxIdleConnTime: "30s" # (可选,v2.10新增)连接池中空闲连接的最大生存时间(默认30秒)。可以通过配置文件或SetConnMaxIdleTime方法设置,避免长时间空闲连接占用资源。
|
||||
createdAt: "created_at" # (可选)自动创建时间字段名称
|
||||
updatedAt: "updated_at" # (可选)自动更新时间字段名称
|
||||
deletedAt: "deleted_at" # (可选)软删除时间字段名称
|
||||
timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性,为true时CreatedAt/UpdatedAt/DeletedAt都将失效
|
||||
rag_vector:
|
||||
- type: "pgsql"
|
||||
host: "192.168.0.169"
|
||||
port: "15432"
|
||||
user: "postgres"
|
||||
pass: "Bjang09@686^*^"
|
||||
name: "tenant-1"
|
||||
prefix: "rag_vector_" # (可选)表名前缀
|
||||
role: "master"
|
||||
debug: true # (可选)开启调试模式
|
||||
dryRun: false # (可选)ORM空跑(只读不写)
|
||||
charset: "utf8" # (可选)数据库编码(如: utf8mb4/utf8/gbk/gb2312),一般设置为utf8mb4。默认为utf8。
|
||||
timezone: "Asia/Shanghai" # (可选)时区配置,例如:Local
|
||||
maxIdle: 5 # (可选)连接池最大闲置的连接数(默认10)
|
||||
maxOpen: 20 # (可选)连接池最大打开的连接数(默认无限制)
|
||||
maxLifetime: "30s" # (可选)连接对象可重复使用的时间长度(默认30秒)
|
||||
maxIdleConnTime: "30s" # (可选,v2.10新增)连接池中空闲连接的最大生存时间(默认30秒)。可以通过配置文件或SetConnMaxIdleTime方法设置,避免长时间空闲连接占用资源。
|
||||
createdAt: "created_at" # (可选)自动创建时间字段名称
|
||||
updatedAt: "updated_at" # (可选)自动更新时间字段名称
|
||||
deletedAt: "deleted_at" # (可选)软删除时间字段名称
|
||||
timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性,为true时CreatedAt/UpdatedAt/DeletedAt都将失效
|
||||
|
||||
redis:
|
||||
default:
|
||||
address: "192.168.0.169:6379"
|
||||
db: 0
|
||||
|
||||
consul:
|
||||
address: 192.168.0.169:8500
|
||||
|
||||
jaeger:
|
||||
addr: 192.168.0.169:4318
|
||||
|
||||
# eino框架配置
|
||||
eino:
|
||||
# 文件切分配置
|
||||
splitter:
|
||||
bufferSize: 1
|
||||
minChunkSize: 64
|
||||
percentile: 0.75
|
||||
# 向量化配置
|
||||
embedding:
|
||||
provider: "dashscope"
|
||||
# apiKey: "d158d896-8c54-40ee-9d61-4c5d37cd545c"
|
||||
# model: "ep-20260326123502-khmdq"
|
||||
# apiType: "multi_modal_api"
|
||||
apiKey: "sk-4a8b82770bf74bc490eb3e4c5a8e2be9"
|
||||
model: "text-embedding-v3"
|
||||
chatmodel:
|
||||
provider: "dashscope"
|
||||
apiKey: "sk-4a8b82770bf74bc490eb3e4c5a8e2be9"
|
||||
model: "qwen-turbo"
|
||||
|
||||
# 文件上传服务地址,与oss模块minio中的endpoint一致
|
||||
filePrefix: "http://192.168.0.169:9000"
|
||||
|
||||
gmq:
|
||||
redis:
|
||||
primary:
|
||||
addr: "192.168.0.169"
|
||||
port: "6379"
|
||||
db: 0
|
||||
username: ""
|
||||
password: ""
|
||||
poolSize: 10
|
||||
minIdleConn: 5
|
||||
maxActiveConn: 10
|
||||
maxRetries: 30
|
||||
|
||||
# Meilisearch 全文检索配置
|
||||
meilisearch:
|
||||
default:
|
||||
host: "http://localhost"
|
||||
port: 7700
|
||||
apiKey: "admin"
|
||||
# apiKey: "6b8b6062bcb5e31f150427961d9da1a9e81758aa"
|
||||
|
||||
cache:
|
||||
localTTL: 60
|
||||
redisTTL: 300
|
||||
51
config.yml
51
config.yml
@@ -12,8 +12,9 @@ database:
|
||||
user: "postgres"
|
||||
pass: "Bjang09@686^*^"
|
||||
name: "rag"
|
||||
prefix: "rag_knowledge_" # (可选)表名前缀
|
||||
role: "master" # (可选)数据库主从角色(master/slave),默认为master。如果不使用应用主从机制请不配置或留空即可。
|
||||
debug: false # (可选)开启调试模式
|
||||
debug: true # (可选)开启调试模式
|
||||
dryRun: false # (可选)ORM空跑(只读不写)
|
||||
charset: "utf8" # (可选)数据库编码(如: utf8mb4/utf8/gbk/gb2312),一般设置为utf8mb4。默认为utf8。
|
||||
timezone: "Asia/Shanghai" # (可选)时区配置,例如:Local
|
||||
@@ -30,7 +31,8 @@ database:
|
||||
port: "15432"
|
||||
user: "postgres"
|
||||
pass: "Bjang09@686^*^"
|
||||
name: "rag"
|
||||
name: "tenant-1"
|
||||
prefix: "rag_knowledge_" # (可选)表名前缀
|
||||
role: "slave" # (可选)数据库主从角色(master/slave),默认为master。如果不使用应用主从机制请不配置或留空即可。
|
||||
debug: false # (可选)开启调试模式
|
||||
dryRun: false # (可选)ORM空跑(只读不写)
|
||||
@@ -44,15 +46,36 @@ database:
|
||||
updatedAt: "updated_at" # (可选)自动更新时间字段名称
|
||||
deletedAt: "deleted_at" # (可选)软删除时间字段名称
|
||||
timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性,为true时CreatedAt/UpdatedAt/DeletedAt都将失效
|
||||
tenant-1:
|
||||
rag_knowledge:
|
||||
- type: "pgsql"
|
||||
host: "localhost"
|
||||
port: "5432"
|
||||
host: "116.204.74.41"
|
||||
port: "15432"
|
||||
user: "postgres"
|
||||
pass: "123456"
|
||||
name: "tenant"
|
||||
pass: "Bjang09@686^*^"
|
||||
name: "tenant-1"
|
||||
prefix: "rag_knowledge_" # (可选)表名前缀
|
||||
role: "master"
|
||||
debug: true # (可选)开启调试模式
|
||||
dryRun: false # (可选)ORM空跑(只读不写)
|
||||
charset: "utf8" # (可选)数据库编码(如: utf8mb4/utf8/gbk/gb2312),一般设置为utf8mb4。默认为utf8。
|
||||
timezone: "Asia/Shanghai" # (可选)时区配置,例如:Local
|
||||
maxIdle: 5 # (可选)连接池最大闲置的连接数(默认10)
|
||||
maxOpen: 20 # (可选)连接池最大打开的连接数(默认无限制)
|
||||
maxLifetime: "30s" # (可选)连接对象可重复使用的时间长度(默认30秒)
|
||||
maxIdleConnTime: "30s" # (可选,v2.10新增)连接池中空闲连接的最大生存时间(默认30秒)。可以通过配置文件或SetConnMaxIdleTime方法设置,避免长时间空闲连接占用资源。
|
||||
createdAt: "created_at" # (可选)自动创建时间字段名称
|
||||
updatedAt: "updated_at" # (可选)自动更新时间字段名称
|
||||
deletedAt: "deleted_at" # (可选)软删除时间字段名称
|
||||
timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性,为true时CreatedAt/UpdatedAt/DeletedAt都将失效
|
||||
rag_vector:
|
||||
- type: "pgsql"
|
||||
host: "116.204.74.41"
|
||||
port: "15432"
|
||||
user: "postgres"
|
||||
pass: "Bjang09@686^*^"
|
||||
name: "tenant-1"
|
||||
prefix: "rag_vector_" # (可选)表名前缀
|
||||
role: "master"
|
||||
prefix: "rag_" # (可选)表名前缀
|
||||
debug: true # (可选)开启调试模式
|
||||
dryRun: false # (可选)ORM空跑(只读不写)
|
||||
charset: "utf8" # (可选)数据库编码(如: utf8mb4/utf8/gbk/gb2312),一般设置为utf8mb4。默认为utf8。
|
||||
@@ -68,14 +91,14 @@ database:
|
||||
|
||||
redis:
|
||||
default:
|
||||
address: "localhost:6379"
|
||||
address: "116.204.74.41:6379"
|
||||
db: 0
|
||||
|
||||
consul:
|
||||
address: localhost:8500
|
||||
address: 116.204.74.41:8500
|
||||
|
||||
jaeger:
|
||||
addr: localhost:4318
|
||||
addr: 116.204.74.41:4318
|
||||
|
||||
# eino框架配置
|
||||
eino:
|
||||
@@ -92,6 +115,10 @@ eino:
|
||||
# apiType: "multi_modal_api"
|
||||
apiKey: "sk-4a8b82770bf74bc490eb3e4c5a8e2be9"
|
||||
model: "text-embedding-v3"
|
||||
chatmodel:
|
||||
provider: "dashscope"
|
||||
apiKey: "sk-4a8b82770bf74bc490eb3e4c5a8e2be9"
|
||||
model: "qwen-turbo"
|
||||
|
||||
# 文件上传服务地址,与oss模块minio中的endpoint一致
|
||||
filePrefix: "http://116.204.74.41:9000"
|
||||
@@ -99,7 +126,7 @@ filePrefix: "http://116.204.74.41:9000"
|
||||
gmq:
|
||||
redis:
|
||||
primary:
|
||||
addr: "localhost"
|
||||
addr: "116.204.74.41"
|
||||
port: "6379"
|
||||
db: 0
|
||||
username: ""
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package public
|
||||
|
||||
const KnowledgeLockEsKey = "rag:knowledge:lock:knowledgeIdEs-%v"
|
||||
const KnowledgeLockSqlKey = "rag:knowledge:lock:knowledgeIdSql-%v"
|
||||
const KnowledgeContentHashEsKey = "rag:knowledge:knowledgeId:contentHashEs-%v"
|
||||
const KnowledgeContentHashSqlKey = "rag:knowledge:knowledgeId:contentHashSql-%v"
|
||||
const KnowledgeLockEsKey = "rag_binary:knowledge:lock:knowledgeIdEs-%v"
|
||||
const KnowledgeLockSqlKey = "rag_binary:knowledge:lock:knowledgeIdSql-%v"
|
||||
const KnowledgeContentHashEsKey = "rag_binary:knowledge:knowledgeId:contentHashEs-%v"
|
||||
const KnowledgeContentHashSqlKey = "rag_binary:knowledge:knowledgeId:contentHashSql-%v"
|
||||
|
||||
const (
|
||||
KnowledgeDocumentVectorStatusTopic = "knowledge:document:vector:status:stream"
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
package public
|
||||
|
||||
// 数据库名称
|
||||
const (
|
||||
DbNameKnowledge = "rag_knowledge"
|
||||
DbNameVector = "rag_vector"
|
||||
)
|
||||
|
||||
// sql 数据库表名
|
||||
const (
|
||||
TableNameDocument = "document"
|
||||
TableNameDataset = "dataset"
|
||||
TableNameKeyword = "keyword"
|
||||
TableNameTask = "task"
|
||||
TableNameDatasetIndex = "dataset_index"
|
||||
TableNameDocumentChunk = "document_chunk"
|
||||
)
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
package controller
|
||||
|
||||
type datasetIndex struct{}
|
||||
|
||||
var DatasetIndex = new(datasetIndex)
|
||||
@@ -48,7 +48,7 @@ func (c *document) List(ctx context.Context, req *dto.ListDocumentReq) (res *dto
|
||||
}
|
||||
|
||||
// Process 处理文件(向量化)
|
||||
func (c *document) Process(ctx context.Context, req *dto.ProcessDocumentReq) (res *dto.ProcessDocumentRes, err error) {
|
||||
res, err = service.Document.Process(ctx, req)
|
||||
func (c *document) Process(ctx context.Context, req *dto.ProcessDocumentReq) (res *beans.ResponseEmpty, err error) {
|
||||
err = service.Document.Process(ctx, req)
|
||||
return
|
||||
}
|
||||
|
||||
17
controller/rag_query.go
Normal file
17
controller/rag_query.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"rag/model/dto"
|
||||
"rag/service"
|
||||
)
|
||||
|
||||
type ragQuery struct{}
|
||||
|
||||
var RAGQuery = new(ragQuery)
|
||||
|
||||
// Query 执行RAG查询
|
||||
func (c *ragQuery) Query(ctx context.Context, req *dto.RAGQueryReq) (res *dto.RAGQueryRes, err error) {
|
||||
res, err = service.RAGQuery.Query(ctx, req)
|
||||
return
|
||||
}
|
||||
@@ -22,7 +22,7 @@ func (d *datasetDao) Insert(ctx context.Context, req *dto.CreateDatasetReq) (id
|
||||
if err = gconv.Struct(req, &res); err != nil {
|
||||
return
|
||||
}
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameDataset).Data(&res).Insert()
|
||||
r, err := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameDataset).Data(&res).Insert()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -31,7 +31,7 @@ func (d *datasetDao) Insert(ctx context.Context, req *dto.CreateDatasetReq) (id
|
||||
|
||||
// Update 更新数据集
|
||||
func (d *datasetDao) Update(ctx context.Context, req *dto.UpdateDatasetReq) (rows int64, err error) {
|
||||
model := gfdb.DB(ctx).Model(ctx, public.TableNameDataset).OmitEmpty()
|
||||
model := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameDataset).OmitEmpty()
|
||||
if !g.IsEmpty(req.DocumentCount) {
|
||||
model.Data(entity.DatasetCol.DocumentCount, &gdb.Counter{
|
||||
Field: entity.DatasetCol.DocumentCount,
|
||||
@@ -53,7 +53,7 @@ func (d *datasetDao) Update(ctx context.Context, req *dto.UpdateDatasetReq) (row
|
||||
|
||||
// Delete 删除数据集
|
||||
func (d *datasetDao) Delete(ctx context.Context, req *dto.DeleteDatasetReq) (rows int64, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameDataset).Where(entity.DatasetCol.Id, req.Id).Delete()
|
||||
r, err := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameDataset).Where(entity.DatasetCol.Id, req.Id).Delete()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -61,7 +61,7 @@ func (d *datasetDao) Delete(ctx context.Context, req *dto.DeleteDatasetReq) (row
|
||||
}
|
||||
|
||||
func (d *datasetDao) GetByID(ctx context.Context, req *dto.GetDatasetReq, fields ...string) (res *entity.Dataset, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameDataset).Where(entity.DatasetCol.Id, req.Id).Fields(fields).One()
|
||||
r, err := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameDataset).Where(entity.DatasetCol.Id, req.Id).Fields(fields).One()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -71,7 +71,7 @@ func (d *datasetDao) GetByID(ctx context.Context, req *dto.GetDatasetReq, fields
|
||||
|
||||
// List 获取数据集列表
|
||||
func (d *datasetDao) List(ctx context.Context, req *dto.ListDatasetReq, fields ...string) (res []*entity.Dataset, total int, err error) {
|
||||
model := gfdb.DB(ctx).Model(ctx, public.TableNameDataset).Fields(fields).OmitEmpty()
|
||||
model := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameDataset).Fields(fields).OmitEmpty()
|
||||
if !g.IsEmpty(req.Keyword) {
|
||||
model.WhereLike(entity.DatasetCol.Name, "%"+req.Keyword+"%")
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ type datasetIndexDao struct{}
|
||||
|
||||
// Insert 插入数据集索引
|
||||
func (d *datasetIndexDao) Insert(ctx context.Context, index *entity.DatasetIndex) (id int64, err error) {
|
||||
_, err = gfdb.DB(ctx).Model(ctx, public.TableNameDatasetIndex).Data(index).Insert()
|
||||
_, err = gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDatasetIndex).Data(index).Insert()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -25,7 +25,7 @@ func (d *datasetIndexDao) Insert(ctx context.Context, index *entity.DatasetIndex
|
||||
|
||||
// GetByDatasetId 根据数据集ID获取索引
|
||||
func (d *datasetIndexDao) GetByDatasetId(ctx context.Context, datasetId int64) (result *entity.DatasetIndex, err error) {
|
||||
err = gfdb.DB(ctx).Model(ctx, public.TableNameDatasetIndex).Where(entity.DatasetIndexCol.DatasetId, datasetId).Scan(&result)
|
||||
err = gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDatasetIndex).Where(entity.DatasetIndexCol.DatasetId, datasetId).Scan(&result)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
@@ -37,23 +37,21 @@ func (d *datasetIndexDao) GetByDatasetId(ctx context.Context, datasetId int64) (
|
||||
|
||||
// IncVectorCount 增加或减少向量数量
|
||||
func (d *datasetIndexDao) IncVectorCount(ctx context.Context, id int64, delta int64) (err error) {
|
||||
_, err = gfdb.DB(ctx).Model(ctx, public.TableNameDatasetIndex).
|
||||
_, err = gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDatasetIndex).
|
||||
Where(entity.DatasetIndexCol.Id, id).
|
||||
Increment(entity.DatasetIndexCol.VectorCount, delta)
|
||||
return
|
||||
}
|
||||
|
||||
func (d *datasetIndexDao) InsertIndex(ctx context.Context, indexName string) (err error) {
|
||||
prefix, err := gfdb.GetTablePrefix(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
db := gfdb.DB(ctx, public.DbNameVector)
|
||||
sqlStr := fmt.Sprintf(`
|
||||
CREATE INDEX IF NOT EXISTS %s
|
||||
ON %s
|
||||
USING ivfflat (vector vector_cosine_ops)
|
||||
WITH (lists = 100)
|
||||
WHERE vector IS NOT NULL;
|
||||
`, indexName, prefix+public.TableNameDocumentChunk)
|
||||
_, err = gfdb.DB(ctx).Exec(ctx, sqlStr)
|
||||
`, indexName, gfdb.TablePrefix+public.TableNameDocumentChunk)
|
||||
_, err = db.Exec(ctx, sqlStr)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ func (d *documentDao) Insert(ctx context.Context, req *dto.CreateDocumentReq) (i
|
||||
if err = gconv.Struct(req, &res); err != nil {
|
||||
return
|
||||
}
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameDocument).Data(&res).Insert()
|
||||
r, err := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameDocument).Data(&res).Insert()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -31,7 +31,7 @@ func (d *documentDao) Insert(ctx context.Context, req *dto.CreateDocumentReq) (i
|
||||
|
||||
// Update 更新文件
|
||||
func (d *documentDao) Update(ctx context.Context, req *dto.UpdateDocumentReq) (rows int64, err error) {
|
||||
model := gfdb.DB(ctx).Model(ctx, public.TableNameDocument).OmitEmpty()
|
||||
model := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameDocument).OmitEmpty()
|
||||
if !g.IsEmpty(req.ChunkCount) {
|
||||
model.Data(entity.DocumentCol.ChunkCount, &gdb.Counter{
|
||||
Field: entity.DocumentCol.ChunkCount,
|
||||
@@ -48,7 +48,7 @@ func (d *documentDao) Update(ctx context.Context, req *dto.UpdateDocumentReq) (r
|
||||
|
||||
// Delete 删除文件
|
||||
func (d *documentDao) Delete(ctx context.Context, req *dto.DeleteDocumentReq) (rows int64, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameDocument).Where(entity.DocumentCol.Id, req.Id).Delete()
|
||||
r, err := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameDocument).Where(entity.DocumentCol.Id, req.Id).Delete()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -57,7 +57,7 @@ func (d *documentDao) Delete(ctx context.Context, req *dto.DeleteDocumentReq) (r
|
||||
|
||||
// GetByID 根据ID获取文件
|
||||
func (d *documentDao) GetByID(ctx context.Context, req *dto.GetDocumentReq, fields ...string) (res *entity.Document, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameDocument).Where(entity.DocumentCol.Id, req.Id).Fields(fields).One()
|
||||
r, err := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameDocument).Where(entity.DocumentCol.Id, req.Id).Fields(fields).One()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -67,7 +67,7 @@ func (d *documentDao) GetByID(ctx context.Context, req *dto.GetDocumentReq, fiel
|
||||
|
||||
// List 获取文件列表
|
||||
func (d *documentDao) List(ctx context.Context, req *dto.ListDocumentReq, fields ...string) (res []*entity.Document, total int, err error) {
|
||||
model := gfdb.DB(ctx).Model(ctx, public.TableNameDocument).OmitEmpty()
|
||||
model := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameDocument).OmitEmpty()
|
||||
if !g.IsEmpty(req.Keyword) {
|
||||
model.WhereLike(entity.DocumentCol.Title, "%"+req.Keyword+"%")
|
||||
}
|
||||
|
||||
@@ -2,12 +2,17 @@ package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"rag/consts/public"
|
||||
"rag/model/dto"
|
||||
"rag/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"gitea.com/red-future/common/full-text-search/meilisearch"
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/text/gstr"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
"github.com/pgvector/pgvector-go"
|
||||
)
|
||||
|
||||
var DocumentChunk = new(documentChunkDao)
|
||||
@@ -20,7 +25,7 @@ func (d *documentChunkDao) BatchInsert(ctx context.Context, req []*dto.VectorDoc
|
||||
if err = gconv.Structs(req, &res); err != nil {
|
||||
return
|
||||
}
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameDocumentChunk).Data(&res).Insert()
|
||||
r, err := gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDocumentChunk).Data(&res).Insert()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -29,7 +34,7 @@ func (d *documentChunkDao) BatchInsert(ctx context.Context, req []*dto.VectorDoc
|
||||
|
||||
// Update 更新文件块
|
||||
func (d *documentChunkDao) Update(ctx context.Context, req *dto.UpdateDocumentChunkReq) (rows int64, err error) {
|
||||
model := gfdb.DB(ctx).Model(ctx, public.TableNameDocumentChunk)
|
||||
model := gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDocumentChunk)
|
||||
r, err := model.Data(&req).Where(entity.DocumentChunkCol.Id, req.Id).Update()
|
||||
if err != nil {
|
||||
return
|
||||
@@ -39,7 +44,7 @@ func (d *documentChunkDao) Update(ctx context.Context, req *dto.UpdateDocumentCh
|
||||
|
||||
// List 文件块列表
|
||||
func (d *documentChunkDao) List(ctx context.Context, req *dto.ListDocumentChunkReq, fields ...string) (res []*entity.DocumentChunk, total int, err error) {
|
||||
model := gfdb.DB(ctx).Model(ctx, public.TableNameDocumentChunk).Fields(fields).OmitEmpty().
|
||||
model := gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDocumentChunk).Fields(fields).OmitEmpty().
|
||||
Where(entity.DocumentChunkCol.DatasetId, req.DatasetId).
|
||||
Where(entity.DocumentChunkCol.DocumentId, req.DocumentId).
|
||||
Where(entity.DocumentChunkCol.Status, req.Status).
|
||||
@@ -56,49 +61,56 @@ func (d *documentChunkDao) List(ctx context.Context, req *dto.ListDocumentChunkR
|
||||
return
|
||||
}
|
||||
|
||||
//// Insert 插入向量文档
|
||||
//func (d *vectorDocumentDao) Insert(ctx context.Context, docs []*entity.DocumentChunk) (ids []interface{}, err error) {
|
||||
// if len(docs) == 0 {
|
||||
// return
|
||||
// }
|
||||
// interfaces := make([]interface{}, len(docs))
|
||||
// for i := range docs {
|
||||
// interfaces[i] = docs[i]
|
||||
// }
|
||||
// return mongoDB.Insert(ctx, interfaces, CollectionVectorDoc)
|
||||
//}
|
||||
//
|
||||
//// DeleteByIDs 根据ID删除向量文档
|
||||
//func (d *vectorDocumentDao) DeleteByIDs(ctx context.Context, ids []string) (err error) {
|
||||
// if len(ids) == 0 {
|
||||
// return
|
||||
// }
|
||||
// objectIDs := make([]bson.ObjectID, len(ids))
|
||||
// for i, id := range ids {
|
||||
// objectIDs[i], err = bson.ObjectIDFromHex(id)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// }
|
||||
// filter := bson.M{"_id": bson.M{"$in": objectIDs}}
|
||||
// _, err = mongoDB.Delete(ctx, filter, CollectionVectorDoc)
|
||||
// return
|
||||
//}
|
||||
//
|
||||
//// GetByIndexID 根据索引ID获取向量文档
|
||||
//func (d *vectorDocumentDao) GetByIndexID(ctx context.Context, indexID string, limit int) (result []*entity.DocumentChunk, err error) {
|
||||
// filter := bson.M{"indexId": indexID}
|
||||
// page := &beans.Page{PageNum: 1, PageSize: int64(limit)}
|
||||
// _, err = mongoDB.Find(ctx, filter, &result, CollectionVectorDoc, page, nil)
|
||||
// return
|
||||
//}
|
||||
//
|
||||
//// GetByVectorIDs 根据向量ID获取向量文档
|
||||
//func (d *vectorDocumentDao) GetByVectorIDs(ctx context.Context, vectorIDs []string) (result []*entity.DocumentChunk, err error) {
|
||||
// if len(vectorIDs) == 0 {
|
||||
// return
|
||||
// }
|
||||
// filter := bson.M{"vectorId": bson.M{"$in": vectorIDs}}
|
||||
// _, err = mongoDB.Find(ctx, filter, &result, CollectionVectorDoc, &beans.Page{PageSize: -1}, nil)
|
||||
// return
|
||||
//}
|
||||
func (d *documentChunkDao) GetAllByVector(ctx context.Context, datasetId []int64, queryVec pgvector.Vector, topK int) (list gdb.List, err error) {
|
||||
sql := `
|
||||
SELECT id, content, dataset_id, document_id,
|
||||
vector <=> ? AS distance
|
||||
FROM rag_vector_document_chunk
|
||||
WHERE dataset_id IN (?)
|
||||
AND vector IS NOT NULL
|
||||
ORDER BY distance ASC
|
||||
LIMIT ?
|
||||
`
|
||||
// 顺序:vector, dataset_id, topK
|
||||
result, err := gfdb.DB(ctx, public.DbNameVector).GetAll(ctx, sql, queryVec, datasetId, topK)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result.List(), nil
|
||||
}
|
||||
|
||||
// SearchByKeywords 通过关键词全文检索文档块
|
||||
func (d *documentChunkDao) SearchByKeywords(ctx context.Context, query string, datasetIds []int64, topK int) (list gdb.List, err error) {
|
||||
// 构建 meilisearch 查询参数
|
||||
searchParams := &meilisearch.SearchParams{
|
||||
Query: query,
|
||||
Limit: int64(topK),
|
||||
ShowRankingScore: true,
|
||||
}
|
||||
|
||||
// 构建 datasetIds 过滤条件
|
||||
if len(datasetIds) > 0 {
|
||||
datasetIdStrs := gconv.Strings(datasetIds)
|
||||
quotedIds := make([]string, len(datasetIdStrs))
|
||||
for i, id := range datasetIdStrs {
|
||||
quotedIds[i] = fmt.Sprintf("%s", id)
|
||||
}
|
||||
searchParams.Filter = fmt.Sprintf("dataset_id IN [%s]", gstr.Implode(", ", quotedIds))
|
||||
}
|
||||
|
||||
// 执行搜索
|
||||
var hits []map[string]interface{}
|
||||
_, err = meilisearch.DB().Search(ctx, searchParams, public.IndexNameDocumentChunk, &hits)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换查询结果为 gdb.List
|
||||
resultList := make(gdb.List, 0, len(hits))
|
||||
for _, hit := range hits {
|
||||
resultList = append(resultList, hit)
|
||||
}
|
||||
|
||||
return resultList, nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ func (d *keywordDao) Insert(ctx context.Context, req *dto.CreateKeywordReq) (id
|
||||
if err = gconv.Struct(req, &res); err != nil {
|
||||
return
|
||||
}
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameKeyword).Data(&res).Insert()
|
||||
r, err := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameKeyword).Data(&res).Insert()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -32,7 +32,7 @@ func (d *keywordDao) BatchSaveOrUpdate(ctx context.Context, req []*dto.CreateKey
|
||||
if err = gconv.Structs(req, &res); err != nil {
|
||||
return
|
||||
}
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameKeyword).Data(&res).OnConflict(
|
||||
r, err := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameKeyword).Data(&res).OnConflict(
|
||||
entity.KeywordCol.TenantId,
|
||||
entity.KeywordCol.DatasetId,
|
||||
entity.KeywordCol.DocumentId,
|
||||
@@ -44,7 +44,7 @@ func (d *keywordDao) BatchSaveOrUpdate(ctx context.Context, req []*dto.CreateKey
|
||||
}
|
||||
|
||||
func (d *keywordDao) Update(ctx context.Context, req *dto.UpdateKeywordReq) (rows int64, err error) {
|
||||
model := gfdb.DB(ctx).Model(ctx, public.TableNameKeyword)
|
||||
model := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameKeyword)
|
||||
r, err := model.Data(&req).Where(entity.KeywordCol.Id, req.Id).Update()
|
||||
if err != nil {
|
||||
return
|
||||
@@ -53,7 +53,7 @@ func (d *keywordDao) Update(ctx context.Context, req *dto.UpdateKeywordReq) (row
|
||||
}
|
||||
|
||||
func (d *keywordDao) Delete(ctx context.Context, req *dto.DeleteKeywordReq) (rows int64, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameKeyword).Where(entity.KeywordCol.Id, req.Id).Delete()
|
||||
r, err := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameKeyword).Where(entity.KeywordCol.Id, req.Id).Delete()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -61,7 +61,7 @@ func (d *keywordDao) Delete(ctx context.Context, req *dto.DeleteKeywordReq) (row
|
||||
}
|
||||
|
||||
func (d *keywordDao) Count(ctx context.Context, req *dto.ListKeywordReq) (count int, err error) {
|
||||
count, err = gfdb.DB(ctx).Model(ctx, public.TableNameKeyword).OmitEmpty().
|
||||
count, err = gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameKeyword).OmitEmpty().
|
||||
Where(entity.KeywordCol.DatasetId, req.DatasetId).
|
||||
Where(entity.KeywordCol.DocumentId, req.DocumentId).
|
||||
Where(entity.KeywordCol.Word, req.Word).Count()
|
||||
@@ -69,7 +69,7 @@ func (d *keywordDao) Count(ctx context.Context, req *dto.ListKeywordReq) (count
|
||||
}
|
||||
|
||||
func (d *keywordDao) GetByID(ctx context.Context, req *dto.GetKeywordReq, fields ...string) (res *entity.Document, err error) {
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameKeyword).Where(entity.KeywordCol.Id, req.Id).Fields(fields).One()
|
||||
r, err := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameKeyword).Where(entity.KeywordCol.Id, req.Id).Fields(fields).One()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -78,10 +78,13 @@ func (d *keywordDao) GetByID(ctx context.Context, req *dto.GetKeywordReq, fields
|
||||
}
|
||||
|
||||
func (d *keywordDao) List(ctx context.Context, req *dto.ListKeywordReq, fields ...string) (res []*entity.Keyword, total int, err error) {
|
||||
model := gfdb.DB(ctx).Model(ctx, public.TableNameKeyword).Fields(fields).OmitEmpty()
|
||||
model := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameKeyword).Fields(fields).OmitEmpty()
|
||||
if !g.IsEmpty(req.Keyword) {
|
||||
model.WhereLike(entity.KeywordCol.Word, "%"+req.Keyword+"%")
|
||||
}
|
||||
model.WhereIn(entity.KeywordCol.Word, req.Words)
|
||||
model.Where(entity.KeywordCol.DatasetId, req.DatasetId)
|
||||
model.Where(entity.KeywordCol.DocumentId, req.DocumentId)
|
||||
model.OrderDesc(entity.KeywordCol.Weight)
|
||||
model.OrderDesc(entity.KeywordCol.CreatedAt)
|
||||
if req.Page != nil {
|
||||
|
||||
58
dao/task.go
Normal file
58
dao/task.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"rag/consts/public"
|
||||
"rag/model/dto"
|
||||
"rag/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
var Task = new(taskDao)
|
||||
|
||||
type taskDao struct{}
|
||||
|
||||
// Insert 创建任务
|
||||
func (d *taskDao) Insert(ctx context.Context, req *dto.CreateTaskReq) (id int64, err error) {
|
||||
var res *entity.Task
|
||||
if err = gconv.Struct(req, &res); err != nil {
|
||||
return
|
||||
}
|
||||
r, err := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameTask).Data(&res).Insert()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return r.LastInsertId()
|
||||
}
|
||||
|
||||
// Update 更新任务
|
||||
func (d *taskDao) Update(ctx context.Context, req *dto.UpdateTaskReq) (rows int64, err error) {
|
||||
model := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameTask)
|
||||
r, err := model.Data(&req).Where(entity.TaskCol.Id, req.Id).Where(entity.TaskCol.TaskId, req.TaskId).OmitEmpty().Update()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return r.RowsAffected()
|
||||
}
|
||||
|
||||
func (d *taskDao) Get(ctx context.Context, req *dto.GetTaskReq) (res []*entity.Task, total int, err error) {
|
||||
r, total, err := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameTask).OmitEmpty().
|
||||
Where(entity.TaskCol.Id, req.Id).
|
||||
Where(entity.TaskCol.TaskId, req.TaskId).
|
||||
Where(entity.TaskCol.TaskType, req.TaskType).AllAndCount(false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = r.Structs(&res)
|
||||
return
|
||||
}
|
||||
|
||||
func (d *taskDao) DeleteByTaskId(ctx context.Context, req *dto.DeleteTaskByTaskIdReq) (rows int64, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameTask).Where(entity.TaskCol.TaskId, req.TaskId).Delete()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return r.RowsAffected()
|
||||
}
|
||||
35
go.mod
35
go.mod
@@ -3,15 +3,26 @@ module rag
|
||||
go 1.26.0
|
||||
|
||||
require (
|
||||
gitea.com/red-future/common v0.0.6
|
||||
gitea.com/red-future/common v0.0.11
|
||||
github.com/bjang03/gmq v0.0.0-00010101000000-000000000000
|
||||
github.com/cloudwego/eino v0.8.6
|
||||
github.com/cloudwego/eino-ext/components/document/loader/url v0.0.0-20260323112355-f061db7e8419
|
||||
github.com/cloudwego/eino-ext/components/document/parser/docx v0.0.0-20260323112355-f061db7e8419
|
||||
github.com/cloudwego/eino-ext/components/document/parser/pdf v0.0.0-20260323112355-f061db7e8419
|
||||
github.com/cloudwego/eino-ext/components/document/parser/xlsx v0.0.0-20260323112355-f061db7e8419
|
||||
github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive v0.0.0-20260323112355-f061db7e8419
|
||||
github.com/cloudwego/eino-ext/components/document/transformer/splitter/semantic v0.0.0-20260323112355-f061db7e8419
|
||||
github.com/cloudwego/eino-ext/components/embedding/ark v0.1.1
|
||||
github.com/cloudwego/eino-ext/components/embedding/dashscope v0.0.0-20260323112355-f061db7e8419
|
||||
github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260323112355-f061db7e8419
|
||||
github.com/cloudwego/eino-ext/components/model/qwen v0.1.7
|
||||
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.0
|
||||
github.com/gogf/gf/v2 v2.10.0
|
||||
github.com/golang/glog v1.2.5
|
||||
github.com/pgvector/pgvector-go v0.3.0
|
||||
)
|
||||
|
||||
replace gitea.com/red-future/common v0.0.6 => ../common
|
||||
replace gitea.com/red-future/common v0.0.11 => ../common
|
||||
|
||||
replace github.com/bjang03/gmq => ../gmq
|
||||
|
||||
@@ -35,19 +46,8 @@ require (
|
||||
github.com/clipperhouse/displaywidth v0.11.0 // indirect
|
||||
github.com/clipperhouse/uax29/v2 v2.7.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||
github.com/cloudwego/eino-ext/components/document/loader/url v0.0.0-20260323112355-f061db7e8419 // indirect
|
||||
github.com/cloudwego/eino-ext/components/document/parser/docx v0.0.0-20260323112355-f061db7e8419 // indirect
|
||||
github.com/cloudwego/eino-ext/components/document/parser/html v0.0.0-20241224063832-9fbcc0e56c28 // indirect
|
||||
github.com/cloudwego/eino-ext/components/document/parser/pdf v0.0.0-20260323112355-f061db7e8419 // indirect
|
||||
github.com/cloudwego/eino-ext/components/document/parser/xlsx v0.0.0-20260323112355-f061db7e8419 // indirect
|
||||
github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive v0.0.0-20260323112355-f061db7e8419 // indirect
|
||||
github.com/cloudwego/eino-ext/components/document/transformer/splitter/semantic v0.0.0-20260323112355-f061db7e8419 // indirect
|
||||
github.com/cloudwego/eino-ext/components/embedding/ark v0.1.1 // indirect
|
||||
github.com/cloudwego/eino-ext/components/embedding/dashscope v0.0.0-20260323112355-f061db7e8419 // indirect
|
||||
github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260323112355-f061db7e8419 // indirect
|
||||
github.com/cloudwego/eino-ext/components/indexer/es8 v0.0.0-20260331071634-4f359694d2d9 // indirect
|
||||
github.com/cloudwego/eino-ext/components/retriever/es8 v0.0.0-20260331071634-4f359694d2d9 // indirect
|
||||
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.14 // indirect
|
||||
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.15 // indirect
|
||||
github.com/dgraph-io/badger/v4 v4.2.0 // indirect
|
||||
github.com/dgraph-io/ristretto v0.1.1 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
@@ -55,8 +55,6 @@ require (
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/eino-contrib/docx2md v0.0.1 // indirect
|
||||
github.com/eino-contrib/jsonschema v1.0.3 // indirect
|
||||
github.com/elastic/elastic-transport-go/v8 v8.10.0 // indirect
|
||||
github.com/elastic/go-elasticsearch/v8 v8.16.0 // indirect
|
||||
github.com/emirpasic/gods/v2 v2.0.0-alpha // indirect
|
||||
github.com/evanphx/json-patch v0.5.2 // indirect
|
||||
github.com/fatih/color v1.19.0 // indirect
|
||||
@@ -74,7 +72,6 @@ require (
|
||||
github.com/gogf/gf/contrib/trace/otlphttp/v2 v2.9.5 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1 // indirect
|
||||
github.com/golang/glog v1.2.5 // indirect
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
||||
github.com/golang/protobuf v1.5.4 // indirect
|
||||
github.com/golang/snappy v1.0.0 // indirect
|
||||
@@ -105,7 +102,7 @@ require (
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.21 // indirect
|
||||
github.com/meguminnnnnnnnn/go-openai v0.1.1 // indirect
|
||||
github.com/meguminnnnnnnnn/go-openai v0.1.2 // indirect
|
||||
github.com/meilisearch/meilisearch-go v0.36.1 // indirect
|
||||
github.com/microcosm-cc/bluemonday v1.0.27 // indirect
|
||||
github.com/mitchellh/go-homedir v1.1.0 // indirect
|
||||
@@ -134,7 +131,7 @@ require (
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/vcaesar/cedar v0.30.0 // indirect
|
||||
github.com/volcengine/volc-sdk-golang v1.0.199 // indirect
|
||||
github.com/volcengine/volcengine-go-sdk v1.0.181 // indirect
|
||||
github.com/volcengine/volcengine-go-sdk v1.2.9 // indirect
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
||||
github.com/xuri/efp v0.0.0-20240408161823-9ad904a10d6d // indirect
|
||||
github.com/xuri/excelize/v2 v2.9.0 // indirect
|
||||
|
||||
23
go.sum
23
go.sum
@@ -154,12 +154,10 @@ github.com/cloudwego/eino-ext/components/embedding/dashscope v0.0.0-202603231123
|
||||
github.com/cloudwego/eino-ext/components/embedding/dashscope v0.0.0-20260323112355-f061db7e8419/go.mod h1:ekJmA+GLD9vJyZNeODZDBFMiJ92Suy6nF0OY42X3sao=
|
||||
github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260323112355-f061db7e8419 h1:eM29lyMShtFZNoAhE5g96+zHg9PBLckRyd2HtVeeY4E=
|
||||
github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260323112355-f061db7e8419/go.mod h1:SajSFFRIXJXIbxadAAlSUIS5KTY8R/jzJg9RNSOXCCI=
|
||||
github.com/cloudwego/eino-ext/components/indexer/es8 v0.0.0-20260331071634-4f359694d2d9 h1:vZ3dL8xwo2sy73aBVKs4AJiO5OCHRxMOJUwIYkp0CWs=
|
||||
github.com/cloudwego/eino-ext/components/indexer/es8 v0.0.0-20260331071634-4f359694d2d9/go.mod h1:+oI0sr0rA0OHCxaQJ0rzMYld3LAODHhPKzBx5JYCya0=
|
||||
github.com/cloudwego/eino-ext/components/retriever/es8 v0.0.0-20260331071634-4f359694d2d9 h1:Sl6giB1SJlA+ZlO0gzPH05IsUORtdYYPN6GiyH1B9MA=
|
||||
github.com/cloudwego/eino-ext/components/retriever/es8 v0.0.0-20260331071634-4f359694d2d9/go.mod h1:H4kNmiTe2irnvipVNIP4q8yqXf2fZ6v24krvQYBtYb8=
|
||||
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.14 h1:yOZII6VYaL00CVZYba+HUixFygsW0Xz/1QjQ5htj1Ls=
|
||||
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.14/go.mod h1:1xMQZ8eE11pkEoTAEy8UlaAY817qGVMvjpDPGSIO3Ns=
|
||||
github.com/cloudwego/eino-ext/components/model/qwen v0.1.7 h1:8c1LB5lH+dERbf2twp18B1Y822JOQSsS6x7Vnksehk0=
|
||||
github.com/cloudwego/eino-ext/components/model/qwen v0.1.7/go.mod h1:n4iuIUQeL3D8GRsGAhkgceRZpoyPQbqOXFMXM2Q4hNY=
|
||||
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.15 h1:LbdSG9+qWzzp9RFW6dSFkaUW171JvCoYn/K63zX6dQE=
|
||||
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.15/go.mod h1:p+l0zBB0GjjX8HTlbTs3g3KfUFwZC11bsCGZOXW/3L0=
|
||||
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
|
||||
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
|
||||
github.com/cncf/xds/go v0.0.0-20210312221358-fbca930ec8ed/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
|
||||
@@ -193,10 +191,6 @@ github.com/eino-contrib/docx2md v0.0.1 h1:Clz0sF8jiQRYAIZAUTuTAjh0vF/1KqHQqsMha1
|
||||
github.com/eino-contrib/docx2md v0.0.1/go.mod h1:b1dupA9cF5yExHjVMCcP6feyE6mwZjsY7Cc9ESO5Y14=
|
||||
github.com/eino-contrib/jsonschema v1.0.3 h1:2Kfsm1xlMV0ssY2nuxshS4AwbLFuqmPmzIjLVJ1Fsp0=
|
||||
github.com/eino-contrib/jsonschema v1.0.3/go.mod h1:cpnX4SyKjWjGC7iN2EbhxaTdLqGjCi0e9DxpLYxddD4=
|
||||
github.com/elastic/elastic-transport-go/v8 v8.10.0 h1:vzpe1BMLdShc7yWNV55U6aGk4UtYEOVsBJ5S4UIeY9Q=
|
||||
github.com/elastic/elastic-transport-go/v8 v8.10.0/go.mod h1:KB6jblnx4NnImxHKULFys7VQ472Av8uzrbkr6OtbOp8=
|
||||
github.com/elastic/go-elasticsearch/v8 v8.16.0 h1:f7bR+iBz8GTAVhwyFO3hm4ixsz2eMaEy0QroYnXV3jE=
|
||||
github.com/elastic/go-elasticsearch/v8 v8.16.0/go.mod h1:lGMlgKIbYoRvay3xWBeKahAiJOgmFDsjZC39nmO3H64=
|
||||
github.com/emirpasic/gods/v2 v2.0.0-alpha h1:dwFlh8pBg1VMOXWGipNMRt8v96dKAIvBehtCt6OtunU=
|
||||
github.com/emirpasic/gods/v2 v2.0.0-alpha/go.mod h1:W0y4M2dtBB9U5z3YlghmpuUhiaZT2h6yoeE+C1sCp6A=
|
||||
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
@@ -529,8 +523,8 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
|
||||
github.com/mattn/go-runewidth v0.0.21 h1:jJKAZiQH+2mIinzCJIaIG9Be1+0NR+5sz/lYEEjdM8w=
|
||||
github.com/mattn/go-runewidth v0.0.21/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
||||
github.com/meguminnnnnnnnn/go-openai v0.1.1 h1:u/IMMgrj/d617Dh/8BKAwlcstD74ynOJzCtVl+y8xAs=
|
||||
github.com/meguminnnnnnnnn/go-openai v0.1.1/go.mod h1:qs96ysDmxhE4BZoU45I43zcyfnaYxU3X+aRzLko/htY=
|
||||
github.com/meguminnnnnnnnn/go-openai v0.1.2 h1:iXombGGjqjBrmE9WaSidUhhi3YQhf42QTHvHLMkgvCA=
|
||||
github.com/meguminnnnnnnnn/go-openai v0.1.2/go.mod h1:qs96ysDmxhE4BZoU45I43zcyfnaYxU3X+aRzLko/htY=
|
||||
github.com/meilisearch/meilisearch-go v0.36.1 h1:mJTCJE5g7tRvaqKco6DfqOuJEjX+rRltDEnkEC02Y0M=
|
||||
github.com/meilisearch/meilisearch-go v0.36.1/go.mod h1:hWcR0MuWLSzHfbz9GGzIr3s9rnXLm1jqkmHkJPbUSvM=
|
||||
github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4=
|
||||
@@ -733,8 +727,8 @@ github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV
|
||||
github.com/volcengine/volc-sdk-golang v1.0.23/go.mod h1:AfG/PZRUkHJ9inETvbjNifTDgut25Wbkm2QoYBTbvyU=
|
||||
github.com/volcengine/volc-sdk-golang v1.0.199 h1:zv9QOqTl/IsLwtfC37GlJtcz6vMAHi+pjq8ILWjLYUc=
|
||||
github.com/volcengine/volc-sdk-golang v1.0.199/go.mod h1:stZX+EPgv1vF4nZwOlEe8iGcriUPRBKX8zA19gXycOQ=
|
||||
github.com/volcengine/volcengine-go-sdk v1.0.181 h1:/3PB4M1N4fjMqiSKTJwX43EZ5Nn1HUOtQrSCk+22+wI=
|
||||
github.com/volcengine/volcengine-go-sdk v1.0.181/go.mod h1:gfEDc1s7SYaGoY+WH2dRrS3qiuDJMkwqyfXWCa7+7oA=
|
||||
github.com/volcengine/volcengine-go-sdk v1.2.9 h1:du2gnImtyWXKkQFnJW/GXCs+UBibGGOXIbP1Ams2pB8=
|
||||
github.com/volcengine/volcengine-go-sdk v1.2.9/go.mod h1:oxoVo+A17kvkwPkIeIHPVLjSw7EQAm+l/Vau1YGHN+A=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
||||
github.com/x-cray/logrus-prefixed-formatter v0.5.2 h1:00txxvfBM9muc0jiLIEAkAcIMJzfthRT6usrui8uGmg=
|
||||
@@ -1191,6 +1185,7 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0
|
||||
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
|
||||
google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
|
||||
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
|
||||
|
||||
21
main.go
21
main.go
@@ -9,8 +9,10 @@ import (
|
||||
"rag/service"
|
||||
"syscall"
|
||||
|
||||
_ "gitea.com/red-future/common/config"
|
||||
"gitea.com/red-future/common/http"
|
||||
"gitea.com/red-future/common/jaeger"
|
||||
"gitea.com/red-future/common/utils"
|
||||
gmq "github.com/bjang03/gmq/core/gmq"
|
||||
"github.com/bjang03/gmq/mq"
|
||||
"github.com/bjang03/gmq/types"
|
||||
@@ -27,22 +29,17 @@ func main() {
|
||||
controller.Dataset,
|
||||
controller.Document,
|
||||
controller.DocumentChunk,
|
||||
controller.Keyword,
|
||||
controller.RAGQuery,
|
||||
})
|
||||
|
||||
gmq.Init("config.yml")
|
||||
|
||||
if err := gmq.GetGmq("primary").GmqSubscribe(ctx, &mq.RedisSubMessage{
|
||||
SubMessage: types.SubMessage{
|
||||
Topic: public.KnowledgeDocumentVectorStatusTopic,
|
||||
ConsumerName: public.KnowledgeDocumentVectorStatusConsumer,
|
||||
AutoAck: public.KnowledgeDocumentVectorStatusAutoAck,
|
||||
FetchCount: public.KnowledgeDocumentVectorStatusBatchSize,
|
||||
HandleFunc: service.Document.DocsVectorStatusMsg,
|
||||
},
|
||||
}); err != nil {
|
||||
return
|
||||
err := utils.InitGseTool(ctx)
|
||||
if err != nil {
|
||||
g.Log().Error(ctx, "gse 分词工具初始化失败:", err)
|
||||
}
|
||||
|
||||
gmq.Init("config.yml")
|
||||
|
||||
if err := gmq.GetGmq("primary").GmqSubscribe(ctx, &mq.RedisSubMessage{
|
||||
SubMessage: types.SubMessage{
|
||||
Topic: public.KnowledgeDocumentChunkTopic,
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
package dto
|
||||
@@ -84,12 +84,6 @@ type ProcessDocumentReq struct {
|
||||
DatasetId int64 `json:"datasetId" v:"required#数据集ID不能为空"`
|
||||
}
|
||||
|
||||
// ProcessDocumentRes 处理文件响应
|
||||
type ProcessDocumentRes struct {
|
||||
ChunkCount int64 `json:"chunkCount"`
|
||||
CostTime int64 `json:"costTime"`
|
||||
}
|
||||
|
||||
type ListDocumentChunkRPC struct {
|
||||
List []*DocumentChunkRPC `json:"list"`
|
||||
}
|
||||
|
||||
@@ -52,6 +52,7 @@ type ListKeywordReq struct {
|
||||
DatasetId int64 `json:"datasetId"`
|
||||
DocumentId int64 `json:"documentId"`
|
||||
Word string `json:"word"`
|
||||
Words []string `json:"words"`
|
||||
Keyword string `json:"keyword" dc:"关键词搜索"`
|
||||
}
|
||||
|
||||
@@ -62,9 +63,11 @@ type ListKeywordRes struct {
|
||||
}
|
||||
|
||||
type KeywordVO struct {
|
||||
Id int64 `json:"id,string" dc:"id"`
|
||||
Word string `json:"word" dc:"关键词名称"`
|
||||
Weight int16 `json:"weight" dc:"权重"`
|
||||
CreatedAt *gtime.Time `json:"createdAt" dc:"创建时间"`
|
||||
UpdatedAt *gtime.Time `json:"updatedAt" dc:"更新时间"`
|
||||
Id int64 `json:"id,string" dc:"id"`
|
||||
Word string `json:"word" dc:"关键词名称"`
|
||||
Weight int16 `json:"weight" dc:"权重"`
|
||||
DatasetId int64 `json:"datasetId,string" dc:"数据集ID"`
|
||||
DocumentId int64 `json:"documentId,string" dc:"文档ID"`
|
||||
CreatedAt *gtime.Time `json:"createdAt" dc:"创建时间"`
|
||||
UpdatedAt *gtime.Time `json:"updatedAt" dc:"更新时间"`
|
||||
}
|
||||
|
||||
25
model/dto/rag_query.go
Normal file
25
model/dto/rag_query.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
// RAGQueryReq RAG查询请求
|
||||
type RAGQueryReq struct {
|
||||
g.Meta `path:"/ragQuery" method:"post" tags:"RAG查询" summary:"执行RAG查询" dc:"执行RAG查询"`
|
||||
|
||||
Content string `json:"content" v:"required#查询内容不能为空" dc:"用户问题"`
|
||||
DatasetIds []int64 `json:"datasetIds" dc:"数据集ID"`
|
||||
History []*Message `json:"history" dc:"历史对话"`
|
||||
TopK int `json:"topK" d:"5" dc:"检索topK,默认5"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// RAGQueryRes RAG查询响应
|
||||
type RAGQueryRes struct {
|
||||
Answer string `json:"answer" dc:"生成的答案"`
|
||||
}
|
||||
65
model/dto/task.go
Normal file
65
model/dto/task.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"rag/common/task"
|
||||
)
|
||||
|
||||
// WriteTaskProgressReq 写入任务进度请求
|
||||
type WriteTaskProgressReq struct {
|
||||
TaskType task.TaskType `json:"taskType" dc:"任务类型"`
|
||||
Status task.TaskStatus `json:"status" dc:"任务状态"`
|
||||
TaskId int64 `json:"taskId" dc:"任务ID"`
|
||||
Remark string `json:"remark" dc:"备注"`
|
||||
}
|
||||
|
||||
// CreateTaskReq 创建任务请求
|
||||
type CreateTaskReq struct {
|
||||
TaskType task.TaskType `json:"taskType" dc:"任务类型"`
|
||||
Status task.TaskStatus `json:"status" dc:"任务状态"`
|
||||
TaskId int64 `json:"taskId" dc:"任务ID"`
|
||||
Remark string `json:"remark" dc:"备注"`
|
||||
}
|
||||
|
||||
// UpdateTaskReq 更新任务请求
|
||||
type UpdateTaskReq struct {
|
||||
Id int64 `json:"id" dc:"任务ID"`
|
||||
TaskId int64 `json:"taskId" dc:"任务ID"`
|
||||
Status task.TaskStatus `json:"status" dc:"任务状态"`
|
||||
Remark string `json:"remark" dc:"备注"`
|
||||
}
|
||||
|
||||
// DeleteTaskByTaskIdReq 删除任务请求
|
||||
type DeleteTaskByTaskIdReq struct {
|
||||
TaskId int64 `json:"taskId" v:"required#任务id不能为空"`
|
||||
}
|
||||
|
||||
// GetTaskReq 获取任务请求
|
||||
type GetTaskReq struct {
|
||||
Id int64 `json:"id" dc:"任务ID"`
|
||||
TaskId int64 `json:"taskId" dc:"任务ID"`
|
||||
TaskType task.TaskType `json:"taskType" dc:"任务类型"`
|
||||
}
|
||||
|
||||
// TaskVO 任务视图对象
|
||||
type TaskVO struct {
|
||||
Id int64 `json:"id" dc:"任务ID"`
|
||||
TaskType task.TaskType `json:"taskType" dc:"任务类型"`
|
||||
Status task.TaskStatus `json:"status" dc:"任务状态"`
|
||||
Priority task.TaskPriority `json:"priority" dc:"任务优先级"`
|
||||
ParentTaskID int64 `json:"parentTaskId" dc:"父任务ID"`
|
||||
TotalItems int64 `json:"totalItems" dc:"总项数"`
|
||||
ProcessedItems int64 `json:"processedItems" dc:"已处理项数"`
|
||||
Progress float64 `json:"progress" dc:"进度百分比"`
|
||||
StartTime *int64 `json:"startTime" dc:"开始时间戳"`
|
||||
EndTime *int64 `json:"endTime" dc:"结束时间戳"`
|
||||
Duration int64 `json:"duration" dc:"耗时(毫秒)"`
|
||||
SuccessCount int64 `json:"successCount" dc:"成功数"`
|
||||
FailCount int64 `json:"failCount" dc:"失败数"`
|
||||
Executor string `json:"executor" dc:"执行器"`
|
||||
DocumentID int64 `json:"documentId" dc:"文档ID"`
|
||||
Remark string `json:"remark" dc:"备注"`
|
||||
Creator string `json:"creator" dc:"创建人"`
|
||||
CreatedAt int64 `json:"createdAt" dc:"创建时间"`
|
||||
Updater string `json:"updater" dc:"更新人"`
|
||||
UpdatedAt int64 `json:"updatedAt" dc:"更新时间"`
|
||||
}
|
||||
66
model/entity/task.go
Normal file
66
model/entity/task.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package entity
|
||||
|
||||
import (
|
||||
"rag/common/task"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
)
|
||||
|
||||
type taskCol struct {
|
||||
beans.SQLBaseCol
|
||||
TaskId string
|
||||
TaskType string
|
||||
Status string
|
||||
Executor string
|
||||
Remark string
|
||||
//Priority string
|
||||
//ParentTaskId string
|
||||
//TotalItems string
|
||||
//ProcessedItems string
|
||||
//Progress string
|
||||
//StartTime string
|
||||
//EndTime string
|
||||
//Duration string
|
||||
//SuccessCount string
|
||||
//FailCount string
|
||||
}
|
||||
|
||||
var TaskCol = taskCol{
|
||||
SQLBaseCol: beans.DefSQLBaseCol,
|
||||
TaskId: "task_id",
|
||||
TaskType: "task_type",
|
||||
Status: "status",
|
||||
Executor: "executor",
|
||||
Remark: "remark",
|
||||
//Priority: "priority",
|
||||
//ParentTaskId: "parent_task_id",
|
||||
//TotalItems: "total_items",
|
||||
//ProcessedItems: "processed_items",
|
||||
//Progress: "progress",
|
||||
//StartTime: "start_time",
|
||||
//EndTime: "end_time",
|
||||
//Duration: "duration",
|
||||
//SuccessCount: "success_count",
|
||||
//FailCount: "fail_count",
|
||||
}
|
||||
|
||||
// Task 任务记录表
|
||||
type Task struct {
|
||||
beans.SQLBaseDO `orm:",inline"`
|
||||
|
||||
TaskId int64 `orm:"task_id" json:"taskId" dc:"任务ID"`
|
||||
TaskType task.TaskType `orm:"task_type" json:"taskType" dc:"任务类型"`
|
||||
Status task.TaskStatus `orm:"status" json:"status" dc:"任务状态"`
|
||||
Executor string `orm:"executor" json:"executor" dc:"执行器"`
|
||||
Remark string `orm:"remark" json:"remark" dc:"备注"`
|
||||
//Priority task.TaskPriority `orm:"priority" json:"priority" dc:"任务优先级"`
|
||||
//ParentTaskId int64 `orm:"parent_task_id" json:"parentTaskId" dc:"父任务ID"`
|
||||
//TotalItems int64 `orm:"total_items" json:"totalItems" dc:"总项数"`
|
||||
//ProcessedItems int64 `orm:"processed_items" json:"processedItems" dc:"已处理项数"`
|
||||
//SuccessCount int64 `orm:"success_count" json:"successCount" dc:"成功数"`
|
||||
//FailCount int64 `orm:"fail_count" json:"failCount" dc:"失败数"`
|
||||
//Progress float64 `orm:"progress" json:"progress" dc:"进度百分比"`
|
||||
//StartTime *gtime.Time `orm:"start_time" json:"startTime" dc:"开始时间戳"`
|
||||
//EndTime *gtime.Time `orm:"end_time" json:"endTime" dc:"结束时间戳"`
|
||||
//Duration int64 `orm:"duration" json:"duration" dc:"耗时(毫秒)"`
|
||||
}
|
||||
BIN
rag_binary
Executable file
BIN
rag_binary
Executable file
Binary file not shown.
@@ -1,5 +0,0 @@
|
||||
package service
|
||||
|
||||
var DatasetIndex = new(datasetIndexService)
|
||||
|
||||
type datasetIndexService struct{}
|
||||
@@ -2,22 +2,20 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"rag/common/eino"
|
||||
"rag/common/task"
|
||||
"rag/consts/document"
|
||||
"rag/consts/public"
|
||||
"rag/dao"
|
||||
"rag/model/dto"
|
||||
"rag/model/entity"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"gitea.com/red-future/common/full-text-search/meilisearch"
|
||||
"gitea.com/red-future/common/http"
|
||||
"gitea.com/red-future/common/rag/eino"
|
||||
"gitea.com/red-future/common/rag/gse"
|
||||
"gitea.com/red-future/common/utils"
|
||||
gmq "github.com/bjang03/gmq/core/gmq"
|
||||
"github.com/bjang03/gmq/mq"
|
||||
@@ -28,6 +26,7 @@ import (
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/database/gredis"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/os/grpool"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
@@ -53,7 +52,13 @@ func (s *documentService) Create(ctx context.Context, req *dto.CreateDocumentReq
|
||||
return
|
||||
}
|
||||
res = &dto.CreateDocumentRes{Id: id}
|
||||
|
||||
// 写入任务进度待处理 任务类型为文档解析
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: id,
|
||||
TaskType: task.TaskTypeDocParse,
|
||||
Status: task.TaskStatusPending,
|
||||
Remark: "文档上传成功待解析: " + req.Title,
|
||||
})
|
||||
return
|
||||
})
|
||||
|
||||
@@ -78,11 +83,20 @@ func (s *documentService) Delete(ctx context.Context, req *dto.DeleteDocumentReq
|
||||
DocumentCount: -1,
|
||||
DocumentSize: -docs.FileSize,
|
||||
}
|
||||
_, err = dao.Dataset.Update(ctx, datasetReq)
|
||||
if err != nil {
|
||||
if _, err = dao.Dataset.Update(ctx, datasetReq); err != nil {
|
||||
return
|
||||
}
|
||||
_, err = dao.Document.Delete(ctx, req)
|
||||
|
||||
if _, err = dao.Document.Delete(ctx, req); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err = dao.Task.DeleteByTaskId(ctx, &dto.DeleteTaskByTaskIdReq{
|
||||
TaskId: docs.Id,
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
})
|
||||
|
||||
@@ -106,115 +120,163 @@ func (s *documentService) List(ctx context.Context, req *dto.ListDocumentReq) (r
|
||||
Total: total,
|
||||
}
|
||||
err = gconv.Struct(list, &res.List)
|
||||
|
||||
//eino.TestIndexer()
|
||||
//eino.TestRetriever()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Process 处理文件(使用eino框架切分和向量化)
|
||||
func (s *documentService) Process(ctx context.Context, req *dto.ProcessDocumentReq) (res *dto.ProcessDocumentRes, err error) {
|
||||
startTime := time.Now()
|
||||
|
||||
func (s *documentService) Process(ctx context.Context, req *dto.ProcessDocumentReq) (err error) {
|
||||
// 1. 查询文件信息
|
||||
documentReq := dto.GetDocumentReq{Id: req.Id}
|
||||
doc, err := dao.Document.GetByID(ctx, &documentReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
if g.IsEmpty(doc) {
|
||||
return errors.New("document not found")
|
||||
}
|
||||
|
||||
// 2. 使用eino框架进行文件切分(并发执行)
|
||||
var vectorDocsCount, chunks int64
|
||||
// 用 gopool 或者简单的错误等待,绝对不用裸 goroutine
|
||||
var err1, err2, err3 error
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(3)
|
||||
|
||||
// 任务1
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
vectorDocsCount, chunks, err1 = s.sqlSplitDocument(ctx, doc)
|
||||
}()
|
||||
|
||||
// 任务2
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err2 = s.esSplitDocument(ctx, doc)
|
||||
}()
|
||||
|
||||
// 任务3
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err3 = s.extractDocument(ctx, doc)
|
||||
}()
|
||||
|
||||
// 直接等待,不使用通道,避免泄漏
|
||||
wg.Wait()
|
||||
|
||||
// 2. 更新文档状态为处理中
|
||||
updateDocumentReq := new(dto.UpdateDocumentReq)
|
||||
updateDocumentReq.Id = req.Id
|
||||
|
||||
// 统一判断错误
|
||||
if err1 != nil || err2 != nil || err3 != nil {
|
||||
// 更新文档状态
|
||||
updateDocumentReq.VectorStatus = document.VectorStatusFailed.Code()
|
||||
if _, err = dao.Document.Update(ctx, updateDocumentReq); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err1 != nil {
|
||||
return nil, err1
|
||||
}
|
||||
if err2 != nil {
|
||||
return nil, err2
|
||||
}
|
||||
return nil, err3
|
||||
}
|
||||
|
||||
// 4. 更新文件状态为处理中和切分数量
|
||||
if vectorDocsCount > 0 {
|
||||
updateDocumentReq.VectorStatus = document.VectorStatusProcessing.Code()
|
||||
} else {
|
||||
updateDocumentReq.VectorStatus = document.VectorStatusCompleted.Code()
|
||||
}
|
||||
updateDocumentReq.ChunkCount = chunks
|
||||
updateDocumentReq.VectorStatus = document.VectorStatusProcessing.Code()
|
||||
if _, err = dao.Document.Update(ctx, updateDocumentReq); err != nil {
|
||||
// 写入任务进度失败 任务类型为文档解析
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: req.Id,
|
||||
TaskType: task.TaskTypeDocParse,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "更新文档状态失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
costTime := time.Since(startTime).Milliseconds()
|
||||
|
||||
return &dto.ProcessDocumentRes{
|
||||
ChunkCount: chunks,
|
||||
CostTime: costTime,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *documentService) extractDocument(ctx context.Context, doc *entity.Document) (err error) {
|
||||
// 1. 加载文件
|
||||
docs, err := s.loadDocument(ctx, doc)
|
||||
// 写入任务进度进行中 任务类型为文档解析
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: req.Id,
|
||||
TaskType: task.TaskTypeDocParse,
|
||||
Status: task.TaskStatusRunning,
|
||||
Remark: "文档解析开始",
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
user, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// ======================
|
||||
// 核心:grpool + g.Try 最佳实践
|
||||
// ======================
|
||||
// 使用带超时的background context,避免HTTP请求完成后context被取消
|
||||
taskCtx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
|
||||
taskCtx = context.WithValue(taskCtx, "user", user)
|
||||
// 任务1: SQL 切分文档
|
||||
grpool.Add(taskCtx, func(ctx context.Context) {
|
||||
g.TryCatch(ctx, func(ctx context.Context) {
|
||||
if innerErr := s.sqlSplitDocument(ctx, doc); innerErr != nil {
|
||||
cancel()
|
||||
}
|
||||
}, func(ctx context.Context, err error) {
|
||||
cancel()
|
||||
})
|
||||
})
|
||||
|
||||
var words []gse.Keyword
|
||||
// 任务2: ES 切分文档
|
||||
grpool.Add(taskCtx, func(ctx context.Context) {
|
||||
g.TryCatch(ctx, func(ctx context.Context) {
|
||||
if innerErr := s.esSplitDocument(ctx, doc); innerErr != nil {
|
||||
cancel()
|
||||
}
|
||||
}, func(ctx context.Context, err error) {
|
||||
cancel()
|
||||
})
|
||||
})
|
||||
|
||||
// 任务3: 提取文档
|
||||
grpool.Add(taskCtx, func(ctx context.Context) {
|
||||
g.TryCatch(ctx, func(ctx context.Context) {
|
||||
if innerErr := s.extractDocument(ctx, doc); innerErr != nil {
|
||||
cancel()
|
||||
}
|
||||
}, func(ctx context.Context, err error) {
|
||||
cancel()
|
||||
})
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractDocument 关键词提取(支持取消)
|
||||
func (s *documentService) extractDocument(ctx context.Context, doc *entity.Document) (err error) {
|
||||
// ========== 取消检查 1:方法入口 ==========
|
||||
if ctx.Err() != nil {
|
||||
// 写入任务进度失败 任务类型为关键字存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeExtractKeywords,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||
})
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// 1. 加载文件
|
||||
docs, err := s.loadDocument(ctx, doc)
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为关键字存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeExtractKeywords,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "加载文件失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var words []utils.Keyword
|
||||
if len(docs[0].Content) < 500 {
|
||||
words = gse.GseTool.Extract(docs[0].Content, 4)
|
||||
words = utils.GseTool.Extract(docs[0].Content, 4)
|
||||
} else if len(docs[0].Content) < 2000 {
|
||||
words = gse.GseTool.Extract(docs[0].Content, 8)
|
||||
words = utils.GseTool.Extract(docs[0].Content, 8)
|
||||
} else if len(docs[0].Content) < 5000 {
|
||||
words = gse.GseTool.Extract(docs[0].Content, 13)
|
||||
words = utils.GseTool.Extract(docs[0].Content, 13)
|
||||
} else {
|
||||
var docsSplit []*schema.Document
|
||||
docsSplit, err = eino.RecursiveSplitDocument(ctx, docs)
|
||||
if err != nil {
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeExtractKeywords,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "递归分割文档失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
// ========== 取消检查 2:循环内部 ==========
|
||||
for _, t := range docsSplit {
|
||||
words = append(words, gse.GseTool.Extract(t.Content, 6)...)
|
||||
if ctx.Err() != nil {
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeExtractKeywords,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||
})
|
||||
return ctx.Err()
|
||||
}
|
||||
words = append(words, utils.GseTool.Extract(t.Content, 6)...)
|
||||
}
|
||||
}
|
||||
|
||||
// ========== 取消检查 3:批量操作前 ==========
|
||||
if ctx.Err() != nil {
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeExtractKeywords,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||
})
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
var keywordReqs = make([]*dto.CreateKeywordReq, 0)
|
||||
for _, word := range words {
|
||||
keywordReqs = append(keywordReqs, &dto.CreateKeywordReq{
|
||||
@@ -227,111 +289,305 @@ func (s *documentService) extractDocument(ctx context.Context, doc *entity.Docum
|
||||
if len(keywordReqs) > 0 {
|
||||
_, err = dao.Keyword.BatchSaveOrUpdate(ctx, keywordReqs)
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为关键字存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeExtractKeywords,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "关键字存储失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
// 写入任务进度已完成 任务类型为关键字存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeExtractKeywords,
|
||||
Status: task.TaskStatusCompleted,
|
||||
Remark: "关键字提取完成",
|
||||
})
|
||||
} else {
|
||||
// 写入任务进度已完成 任务类型为关键字存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeExtractKeywords,
|
||||
Status: task.TaskStatusCompleted,
|
||||
Remark: "没有提取到关键词,关键字提取完成",
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Document) (vectorDocsCount, docsSplitCount int64, err error) {
|
||||
// sqlSplitDocument SQL切分(支持取消)
|
||||
func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Document) (err error) {
|
||||
// ========== 取消检查 1:方法入口 ==========
|
||||
if ctx.Err() != nil {
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||
})
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// 1. 加载文件
|
||||
docs, err := s.loadDocument(ctx, doc)
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为sql存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "加载文件失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 语义切分文件
|
||||
docsSplit, err := eino.SemanticSplitDocument(ctx, docs)
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为sql存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "文档切分失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
docsSplitCount = gconv.Int64(len(docsSplit))
|
||||
|
||||
// 2. 获取历史数据
|
||||
err = s.getHistoryData(ctx, doc, public.KnowledgeLockSqlKey, public.KnowledgeContentHashSqlKey)
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为sql存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "获取历史数据失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 组装向量文档
|
||||
var vectorDocs = make([]dto.VectorDocumentChunkMsg, 0)
|
||||
var docsChunk = make([]*schema.Document, 0)
|
||||
for i, t := range docsSplit {
|
||||
// ========== 取消检查 2:循环内部 ==========
|
||||
if ctx.Err() != nil {
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||
})
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
contentHash := gmd5.MustEncryptString(t.Content)
|
||||
// 检查是否重复
|
||||
var success bool
|
||||
success, err = s.checkRepeat(ctx, public.KnowledgeContentHashSqlKey, contentHash)
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为sql存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "检查重复数据失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if !success {
|
||||
continue
|
||||
}
|
||||
vectorDocs = append(vectorDocs, dto.VectorDocumentChunkMsg{
|
||||
TenantId: doc.TenantId,
|
||||
Creator: doc.Creator,
|
||||
DatasetId: doc.DatasetId,
|
||||
DocumentId: doc.Id,
|
||||
Content: t.Content,
|
||||
ContentHash: contentHash,
|
||||
ChunkIndex: gconv.Int64(i),
|
||||
})
|
||||
|
||||
var metaData = make(map[string]any)
|
||||
metaData[entity.DocumentCol.TenantId] = doc.TenantId
|
||||
metaData[entity.DocumentCol.Creator] = doc.Creator
|
||||
metaData[entity.DocumentCol.DatasetId] = doc.DatasetId
|
||||
metaData[entity.DocumentChunkCol.DocumentId] = doc.Id
|
||||
metaData[entity.DocumentChunkCol.ContentHash] = contentHash
|
||||
metaData[entity.DocumentChunkCol.ChunkIndex] = gconv.Int64(i)
|
||||
t.MetaData = metaData
|
||||
docsChunk = append(docsChunk, t)
|
||||
}
|
||||
|
||||
// ========== 取消检查 3:批量发送前 ==========
|
||||
if ctx.Err() != nil {
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||
})
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// 4. 发送消息到队列
|
||||
if len(vectorDocs) > 0 {
|
||||
if len(docsChunk) > 0 {
|
||||
err = gmq.GetGmq("primary").GmqPublish(ctx, &mq.RedisPubMessage{
|
||||
PubMessage: types.PubMessage{
|
||||
Topic: public.KnowledgeDocumentChunkTopic,
|
||||
Data: vectorDocs,
|
||||
Data: docsChunk,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为sql存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "发送消息到队列失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
// 写入任务进度进行中 任务类型为sql存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusRunning,
|
||||
Remark: "向量生成任务已提交到队列",
|
||||
})
|
||||
} else {
|
||||
// 写入任务进度已完成 任务类型为sql存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusCompleted,
|
||||
Remark: "无需生成向量,任务完成",
|
||||
})
|
||||
}
|
||||
vectorDocsCount = gconv.Int64(len(vectorDocs))
|
||||
return
|
||||
}
|
||||
|
||||
// esSplitDocument ES切分(支持取消)
|
||||
func (s *documentService) esSplitDocument(ctx context.Context, doc *entity.Document) (err error) {
|
||||
// ========== 取消检查 1:方法入口 ==========
|
||||
if ctx.Err() != nil {
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeFullTextSearch,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||
})
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// 1. 加载文件
|
||||
docs, err := s.loadDocument(ctx, doc)
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为es存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeFullTextSearch,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "加载文件失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 递归切分文件
|
||||
docsSplit, err := eino.RecursiveSplitDocument(ctx, docs)
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为es存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeFullTextSearch,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "文档切分失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 获取历史数据
|
||||
err = s.getHistoryData(ctx, doc, public.KnowledgeLockEsKey, public.KnowledgeContentHashEsKey)
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为es存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeFullTextSearch,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "获取历史数据失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 组装向量文档并同时构建meilisearch文档
|
||||
var meiliDocs = make([]interface{}, 0)
|
||||
for i, t := range docsSplit {
|
||||
// ========== 取消检查 2:循环内部 ==========
|
||||
if ctx.Err() != nil {
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeFullTextSearch,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||
})
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
contentHash := gmd5.MustEncryptString(t.Content)
|
||||
// 检查是否重复
|
||||
var success bool
|
||||
success, err = s.checkRepeat(ctx, public.KnowledgeContentHashEsKey, contentHash)
|
||||
if err != nil {
|
||||
// 写入任务进度失败 任务类型为es存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeFullTextSearch,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "检查重复数据失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if !success {
|
||||
continue
|
||||
}
|
||||
// 构建Meilisearch文档
|
||||
meiliDocs = append(meiliDocs, map[string]interface{}{
|
||||
"id": contentHash,
|
||||
"datasetId": doc.DatasetId,
|
||||
"documentId": doc.Id,
|
||||
"content": t.Content,
|
||||
"contentHash": contentHash,
|
||||
"chunkIndex": i,
|
||||
entity.DocumentChunkCol.Id: contentHash,
|
||||
entity.DocumentChunkCol.DatasetId: doc.DatasetId,
|
||||
entity.DocumentChunkCol.DocumentId: doc.Id,
|
||||
entity.DocumentChunkCol.Content: t.Content,
|
||||
entity.DocumentChunkCol.ContentHash: contentHash,
|
||||
entity.DocumentChunkCol.ChunkIndex: i,
|
||||
})
|
||||
}
|
||||
|
||||
// ========== 取消检查 3:批量写入前 ==========
|
||||
if ctx.Err() != nil {
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeFullTextSearch,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "ctx取消: " + ctx.Err().Error(),
|
||||
})
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// 4. 写入到meilisearch数据库中
|
||||
if len(meiliDocs) > 0 {
|
||||
if _, err = meilisearch.DB().InsertMany(ctx, meiliDocs, public.IndexNameDocumentChunk); err != nil {
|
||||
g.Log().Errorf(ctx, "写入meilisearch失败: %v", err)
|
||||
// 写入任务进度失败 任务类型为meilisearch存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeFullTextSearch,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: "写入meilisearch失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
// 写入任务进度已完成 任务类型为meilisearch存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeFullTextSearch,
|
||||
Status: task.TaskStatusCompleted,
|
||||
Remark: "全文检索数据写入完成",
|
||||
})
|
||||
} else {
|
||||
// 写入任务进度已完成 任务类型为meilisearch存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: doc.Id,
|
||||
TaskType: task.TaskTypeFullTextSearch,
|
||||
Status: task.TaskStatusCompleted,
|
||||
Remark: "无需生成全文检索数据,任务完成",
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -403,23 +659,12 @@ func (s *documentService) getHistoryData(ctx context.Context, doc *entity.Docume
|
||||
|
||||
// getHistoryDataFromHttp 通过 HTTP 接口查询历史数据
|
||||
func (s *documentService) getHistoryDataFromHttp(ctx context.Context, doc *entity.Document) (dictData []*dto.DocumentChunkRPC, err error) {
|
||||
headers := make(map[string]string)
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
for k, v := range r.Request.Header {
|
||||
if len(v) > 0 {
|
||||
headers[k] = v[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 调用接口获取数据
|
||||
d := &dto.ListDocumentChunkRPC{}
|
||||
if err = http.Get(ctx, "rag-vector/document/chunk/listDocumentChunk", headers, &d,
|
||||
"datasetId", gconv.String(doc.DatasetId),
|
||||
"status", 1); err != nil {
|
||||
return
|
||||
}
|
||||
dictData = d.List
|
||||
res, _, err := dao.DocumentChunk.List(ctx, &dto.ListDocumentChunkReq{
|
||||
DatasetId: doc.DatasetId,
|
||||
Status: gconv.PtrInt8(1),
|
||||
})
|
||||
err = gconv.Struct(res, &dictData)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -464,20 +709,3 @@ func (s *documentService) checkRepeat(ctx context.Context, contentKey, contentHa
|
||||
success = val.Bool()
|
||||
return
|
||||
}
|
||||
|
||||
func (s *documentService) DocsVectorStatusMsg(ctx context.Context, msg any) (err error) {
|
||||
var req = new(dto.KnowledgeDocumentMsg)
|
||||
if err = gconv.Struct(msg, &req); err != nil {
|
||||
g.Log().Error(ctx, "DocsVectorStatusMsg err:", err)
|
||||
return
|
||||
}
|
||||
ctx = context.WithValue(ctx, "user", &beans.User{
|
||||
TenantId: req.TenantId,
|
||||
UserName: req.Creator,
|
||||
})
|
||||
_, err = dao.Document.Update(ctx, &dto.UpdateDocumentReq{
|
||||
Id: req.Id,
|
||||
VectorStatus: req.VectorStatus,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,33 +2,23 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"rag/consts/document"
|
||||
"rag/consts/public"
|
||||
"rag/common/eino"
|
||||
"rag/common/task"
|
||||
"rag/dao"
|
||||
"rag/model/dto"
|
||||
"rag/model/entity"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
"gitea.com/red-future/common/rag/eino"
|
||||
gmq "github.com/bjang03/gmq/core/gmq"
|
||||
"github.com/bjang03/gmq/mq"
|
||||
"github.com/bjang03/gmq/types"
|
||||
"github.com/cloudwego/eino/components/indexer"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
"github.com/pgvector/pgvector-go"
|
||||
)
|
||||
|
||||
var DocumentChunk = new(documentChunkService)
|
||||
|
||||
type documentChunkService struct{}
|
||||
|
||||
const (
|
||||
DatasetIndexStatusReady = "ready"
|
||||
)
|
||||
|
||||
// Update 更新文件块
|
||||
func (s *documentChunkService) Update(ctx context.Context, req *dto.UpdateDocumentChunkReq) (err error) {
|
||||
_, err = dao.DocumentChunk.Update(ctx, req)
|
||||
@@ -49,128 +39,46 @@ func (s *documentChunkService) List(ctx context.Context, req *dto.ListDocumentCh
|
||||
}
|
||||
|
||||
func (s *documentChunkService) DocsChunkMsg(ctx context.Context, msg any) (err error) {
|
||||
var req = make([]*dto.VectorDocumentChunkMsg, 0)
|
||||
var docs = make([]*schema.Document, 0)
|
||||
msgMap := gconv.Map(msg)
|
||||
if err = gconv.Structs(msgMap["data"], &req); err != nil {
|
||||
if err = gconv.Structs(msgMap["data"], &docs); err != nil {
|
||||
g.Log().Error(ctx, "DocsChunkMsg err:", err)
|
||||
return
|
||||
}
|
||||
if len(req) == 0 {
|
||||
if len(docs) == 0 {
|
||||
g.Log().Error(ctx, "DocsChunkMsg err:", "msg is empty")
|
||||
return
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, "user", &beans.User{
|
||||
TenantId: req[0].TenantId,
|
||||
UserName: req[0].Creator,
|
||||
TenantId: gconv.Uint64(docs[0].MetaData[entity.DocumentChunkCol.TenantId]),
|
||||
UserName: gconv.String(docs[0].MetaData[entity.DocumentChunkCol.Creator]),
|
||||
})
|
||||
|
||||
// 调用eino接口获取向量
|
||||
var vectorDocsStr = make([]string, 0, len(req))
|
||||
for _, t := range req {
|
||||
vectorDocsStr = append(vectorDocsStr, t.Content)
|
||||
}
|
||||
embeddings, err := eino.EmbedStrings(ctx, vectorDocsStr)
|
||||
if err != nil {
|
||||
g.Log().Error(ctx, "DocsChunkMsg err:", err)
|
||||
err = s.publishKnowledgeDocumentMsg(ctx, req[0].TenantId, req[0].Creator, req[0].DocumentId, document.VectorStatusFailed.Code())
|
||||
idx := eino.NewPGVectorIndexer(&eino.PGVectorIndexerOptions{
|
||||
BatchSize: 10,
|
||||
})
|
||||
documentId := gconv.Int64(docs[0].MetaData[entity.DocumentChunkCol.DocumentId])
|
||||
rows, err := idx.Store(ctx, docs, indexer.WithEmbedding(eino.EmbedderDashscope))
|
||||
if err != nil || rows == 0 {
|
||||
g.Log().Error(ctx, "DocsChunkMsg rows: , err:", rows, err)
|
||||
// 写入任务进度失败 任务类型为sql存储
|
||||
remark := " 向量存储数量: " + gconv.String(rows)
|
||||
if err != nil {
|
||||
remark = "向量存储失败: " + err.Error()
|
||||
}
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: documentId,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusFailed,
|
||||
Remark: remark,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取向量维度
|
||||
dimension := 0
|
||||
if len(embeddings) > 0 {
|
||||
dimension = len(embeddings[0])
|
||||
}
|
||||
|
||||
// 创建或更新DatasetIndex
|
||||
err = s.createOrUpdateDatasetIndex(ctx, req[0].DatasetId, dimension, int64(len(req)))
|
||||
if err != nil {
|
||||
g.Log().Error(ctx, "CreateOrUpdateDatasetIndex err:", err)
|
||||
err = s.publishKnowledgeDocumentMsg(ctx, req[0].TenantId, req[0].Creator, req[0].DocumentId, document.VectorStatusFailed.Code())
|
||||
return
|
||||
}
|
||||
|
||||
// 更新向量文档
|
||||
for i, embedding := range embeddings {
|
||||
req[i].Vector = pgvector.NewVector(gconv.Float32s(embedding))
|
||||
req[i].VectorStatus = document.VectorStatusCompleted.Code()
|
||||
req[i].Status = document.StatusEnable.Code()
|
||||
}
|
||||
_, err = dao.DocumentChunk.BatchInsert(ctx, req)
|
||||
if err != nil {
|
||||
g.Log().Error(ctx, "DocsChunkMsg err:", err)
|
||||
err = s.publishKnowledgeDocumentMsg(ctx, req[0].TenantId, req[0].Creator, req[0].DocumentId, document.VectorStatusFailed.Code())
|
||||
return
|
||||
}
|
||||
|
||||
err = s.publishKnowledgeDocumentMsg(ctx, req[0].TenantId, req[0].Creator, req[0].DocumentId, document.VectorStatusCompleted.Code())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// createOrUpdateDatasetIndex 创建或更新数据集索引
|
||||
func (s *documentChunkService) createOrUpdateDatasetIndex(ctx context.Context, datasetId int64, dimension int, vectorCount int64) (err error) {
|
||||
// 查询数据集是否已有索引
|
||||
existIndex, err := dao.DatasetIndex.GetByDatasetId(ctx, datasetId)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return err
|
||||
}
|
||||
|
||||
// 已有索引 → 只更新数量
|
||||
if existIndex != nil {
|
||||
_ = dao.DatasetIndex.IncVectorCount(ctx, existIndex.Id, vectorCount)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ====================== 创建新索引 ======================
|
||||
indexName := fmt.Sprintf("idx_dataset_%d_vector", datasetId) // 真实PG索引名
|
||||
// 1. 插入索引配置
|
||||
index := &entity.DatasetIndex{
|
||||
DatasetId: datasetId,
|
||||
Name: indexName,
|
||||
Dimension: dimension,
|
||||
FieldType: "float",
|
||||
MetricType: "COSINE",
|
||||
Status: gconv.PtrInt8(1),
|
||||
VectorCount: vectorCount,
|
||||
Description: fmt.Sprintf("数据集%d向量索引", datasetId),
|
||||
}
|
||||
_, err = dao.DatasetIndex.Insert(ctx, index)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 2. 真正创建 PGVector 索引(唯一真实索引!)
|
||||
err = s.createRealPGVectorIndex(ctx, indexName)
|
||||
return err
|
||||
}
|
||||
|
||||
// createRealPGVectorIndex 真正在PostgreSQL创建向量索引(真实可用)
|
||||
func (s *documentChunkService) createRealPGVectorIndex(ctx context.Context, indexName string) error {
|
||||
// 执行真实建索引语句
|
||||
err := dao.DatasetIndex.InsertIndex(ctx, indexName)
|
||||
if err != nil {
|
||||
g.Log().Error(ctx, "创建向量索引失败:", err)
|
||||
return err
|
||||
}
|
||||
g.Log().Info(ctx, "PGVector真实索引创建成功:"+indexName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// publishKnowledgeDocumentMsg 发布消息
|
||||
func (s *documentChunkService) publishKnowledgeDocumentMsg(ctx context.Context, tenantId uint64, creator string, documentId int64, vectorStatus document.VectorStatus) (err error) {
|
||||
knowledgeDocumentMsg := dto.KnowledgeDocumentMsg{
|
||||
TenantId: tenantId,
|
||||
Creator: creator,
|
||||
Id: documentId,
|
||||
VectorStatus: vectorStatus,
|
||||
}
|
||||
err = gmq.GetGmq("primary").GmqPublish(ctx, &mq.RedisPubMessage{
|
||||
PubMessage: types.PubMessage{
|
||||
Topic: public.KnowledgeDocumentVectorStatusTopic,
|
||||
Data: knowledgeDocumentMsg,
|
||||
},
|
||||
// 写入任务进度成功 任务类型为sql存储
|
||||
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
|
||||
TaskId: documentId,
|
||||
TaskType: task.TaskTypeGenerateVector,
|
||||
Status: task.TaskStatusCompleted,
|
||||
Remark: "向量生成完成",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
60
service/rag_query.go
Normal file
60
service/rag_query.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"rag/common/eino"
|
||||
"rag/model/dto"
|
||||
|
||||
"github.com/cloudwego/eino/components/retriever"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/gogf/gf/v2/os/glog"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
var RAGQuery = new(ragQueryService)
|
||||
|
||||
type ragQueryService struct{}
|
||||
|
||||
// Query 执行RAG查询
|
||||
func (s *ragQueryService) Query(ctx context.Context, req *dto.RAGQueryReq) (*dto.RAGQueryRes, error) {
|
||||
if req.TopK <= 0 {
|
||||
req.TopK = 5
|
||||
}
|
||||
|
||||
// 4. 使用向量检索器进行查询
|
||||
r, err := eino.NewPGVectorRetriever(&eino.PGVectorRetrieverConfig{
|
||||
Embedder: eino.EmbedderDashscope,
|
||||
DefaultTopK: req.TopK,
|
||||
})
|
||||
if err != nil {
|
||||
glog.Errorf(ctx, "初始化向量检索器失败: %v", err)
|
||||
return nil, fmt.Errorf("初始化向量检索器失败: %w", err)
|
||||
}
|
||||
|
||||
// 5. 执行向量检索
|
||||
docs, err := r.Retrieve(ctx, req.Content, retriever.WithEmbedding(eino.EmbedderDashscope), retriever.WithDSLInfo(map[string]any{
|
||||
"dataset_ids": req.DatasetIds,
|
||||
}))
|
||||
if err != nil {
|
||||
glog.Errorf(ctx, "向量检索失败: %v", err)
|
||||
return nil, fmt.Errorf("向量检索失败: %w", err)
|
||||
}
|
||||
|
||||
messages := make([]*schema.Message, 0)
|
||||
err = gconv.Struct(req.History, &messages)
|
||||
if err != nil {
|
||||
glog.Errorf(ctx, "转换历史消息失败: %v", err)
|
||||
return nil, fmt.Errorf("转换历史消息失败: %w", err)
|
||||
}
|
||||
|
||||
replyMsg, err := eino.NewChatModel(ctx, req.Content, docs, messages)
|
||||
if err != nil {
|
||||
glog.Errorf(ctx, "向量检索失败: %v", err)
|
||||
return nil, fmt.Errorf("向量检索失败: %w", err)
|
||||
}
|
||||
|
||||
return &dto.RAGQueryRes{
|
||||
Answer: replyMsg.Content,
|
||||
}, nil
|
||||
}
|
||||
116
service/task.go
Normal file
116
service/task.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"rag/dao"
|
||||
"rag/model/dto"
|
||||
|
||||
"rag/common/task"
|
||||
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
var Task = new(taskService)
|
||||
|
||||
type taskService struct{}
|
||||
|
||||
// WriteTaskProgress 写入任务进度(核心方法)
|
||||
func (s *taskService) WriteTaskProgress(ctx context.Context, req *dto.WriteTaskProgressReq) (err error) {
|
||||
t, total, err := dao.Task.Get(ctx, &dto.GetTaskReq{
|
||||
TaskId: req.TaskId,
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "查询任务失败: %v", err)
|
||||
return err
|
||||
}
|
||||
completed := false
|
||||
if total != 0 {
|
||||
taskVO := make([]*dto.TaskVO, 0, total)
|
||||
err = gconv.Struct(t, &taskVO)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "转换任务失败: %v", err)
|
||||
return err
|
||||
}
|
||||
taskVO = append(taskVO, &dto.TaskVO{
|
||||
TaskType: req.TaskType,
|
||||
Status: req.Status,
|
||||
})
|
||||
completed = IsAllSubTasksCompleted(taskVO)
|
||||
}
|
||||
|
||||
// 1. 查询是否已存在该文档的该类型任务
|
||||
existTask, _, err := dao.Task.Get(ctx, &dto.GetTaskReq{
|
||||
TaskId: req.TaskId,
|
||||
TaskType: req.TaskType,
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "查询任务失败: %v", err)
|
||||
return err
|
||||
}
|
||||
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
||||
// 2. 如果不存在,则创建新任务
|
||||
if g.IsEmpty(existTask) {
|
||||
createReq := &dto.CreateTaskReq{
|
||||
TaskId: req.TaskId,
|
||||
TaskType: req.TaskType,
|
||||
Status: req.Status,
|
||||
Remark: req.Remark,
|
||||
}
|
||||
_, err = dao.Task.Insert(ctx, createReq)
|
||||
} else {
|
||||
// 3. 如果已存在,则更新任务
|
||||
updateReq := &dto.UpdateTaskReq{
|
||||
Id: existTask[0].Id,
|
||||
Status: req.Status,
|
||||
Remark: req.Remark,
|
||||
}
|
||||
_, err = dao.Task.Update(ctx, updateReq)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "更新任务失败: %v", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if completed {
|
||||
// 3. 如果已存在,则更新任务
|
||||
_, err = dao.Task.Update(ctx, &dto.UpdateTaskReq{
|
||||
TaskId: req.TaskId,
|
||||
Status: task.TaskStatusCompleted,
|
||||
Remark: "文档解析完成",
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// IsAllSubTasksCompleted 判断三个子任务是否全部完成
|
||||
// 参数:传入当前文档的所有子任务列表
|
||||
func IsAllSubTasksCompleted(subTasks []*dto.TaskVO) bool {
|
||||
// 必须包含 3 种任务类型
|
||||
hasKeywords := false
|
||||
hasVector := false
|
||||
hasFullText := false
|
||||
|
||||
for _, t := range subTasks {
|
||||
// 子任务必须是【已完成】状态才计数
|
||||
if t.Status == task.TaskStatusCompleted {
|
||||
switch t.TaskType {
|
||||
case task.TaskTypeExtractKeywords:
|
||||
hasKeywords = true
|
||||
case task.TaskTypeGenerateVector:
|
||||
hasVector = true
|
||||
case task.TaskTypeFullTextSearch:
|
||||
hasFullText = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 三个任务全部完成 → 返回true
|
||||
return hasKeywords && hasVector && hasFullText
|
||||
}
|
||||
BIN
timezone/Shanghai
Normal file
BIN
timezone/Shanghai
Normal file
Binary file not shown.
BIN
timezone/localtime
Normal file
BIN
timezone/localtime
Normal file
Binary file not shown.
1
timezone/timezone
Normal file
1
timezone/timezone
Normal file
@@ -0,0 +1 @@
|
||||
Asia/Shanghai
|
||||
167
update.sql
167
update.sql
@@ -114,6 +114,7 @@ COMMENT ON COLUMN rag_knowledge_document.file_path IS '文件存储路径(如M
|
||||
COMMENT ON COLUMN rag_knowledge_document.metadata IS '文件元数据,结构:{"author":"作者","tags":["标签1","标签2"],"custom":{"key":"值"}}';
|
||||
|
||||
--------------------pgsql创建rag_knowledge_document表语句---------------------------
|
||||
|
||||
--------------------pgsql创建rag_knowledge_keyword表语句---------------------------
|
||||
-- 关键词表(文档关键词+权重)
|
||||
CREATE TABLE IF NOT EXISTS rag_knowledge_keyword (
|
||||
@@ -134,9 +135,9 @@ CREATE TABLE IF NOT EXISTS rag_knowledge_keyword (
|
||||
);
|
||||
|
||||
-- 唯一索引:保证 租户 + 数据集 + 文档 + 关键词 全局唯一
|
||||
CREATE UNIQUE INDEX uk_rag_knowledge_keyword_tenant_dataset_doc_word
|
||||
ON rag_knowledge_keyword(tenant_id, dataset_id, document_id, word)
|
||||
WHERE deleted_at IS NULL;
|
||||
-- CREATE UNIQUE INDEX uk_rag_knowledge_keyword_tenant_dataset_doc_word
|
||||
-- ON rag_knowledge_keyword(tenant_id, dataset_id, document_id, word)
|
||||
-- WHERE deleted_at IS NULL;
|
||||
|
||||
-- 索引(按业务高频查询)
|
||||
CREATE INDEX idx_keyword_tenant_id ON rag_knowledge_keyword(tenant_id);
|
||||
@@ -158,5 +159,163 @@ COMMENT ON COLUMN rag_knowledge_keyword.dataset_id IS '数据集ID';
|
||||
COMMENT ON COLUMN rag_knowledge_keyword.document_id IS '文档ID';
|
||||
COMMENT ON COLUMN rag_knowledge_keyword.word IS '关键词';
|
||||
COMMENT ON COLUMN rag_knowledge_keyword.weight IS '权重';
|
||||
CREATE UNIQUE INDEX uk_rag_knowledge_keyword_tenant_dataset_doc_word ON rag_knowledge_keyword (tenant_id, dataset_id, document_id, word);
|
||||
--------------------pgsql创建rag_knowledge_keyword表语句---------------------------
|
||||
|
||||
--------------------pgsql创建rag_knowledge_keyword表语句---------------------------
|
||||
--------------------pgsql创建rag_knowledge_task表语句---------------------------
|
||||
-- 知识库任务表
|
||||
CREATE TABLE IF NOT EXISTS rag_knowledge_task (
|
||||
-- 基础字段(完全对齐项目规范)
|
||||
id BIGINT PRIMARY KEY, -- 主键ID(非自增)
|
||||
tenant_id BIGINT NOT NULL DEFAULT 0, -- 租户ID int8
|
||||
creator VARCHAR(64) NOT NULL,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updater VARCHAR(64) NOT NULL,
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
deleted_at timestamp(6),
|
||||
|
||||
-- 业务字段
|
||||
task_id BIGINT NOT NULL, -- 任务ID
|
||||
task_type VARCHAR(32) NOT NULL, -- 任务类型
|
||||
status VARCHAR(32) NOT NULL, -- 任务状态
|
||||
executor VARCHAR(128) DEFAULT '', -- 执行器
|
||||
remark TEXT DEFAULT '' -- 备注
|
||||
);
|
||||
|
||||
-- 索引(高频查询)
|
||||
CREATE INDEX idx_rkt_tenant_id ON rag_knowledge_task(tenant_id);
|
||||
CREATE INDEX idx_rkt_task_id ON rag_knowledge_task(task_id);
|
||||
CREATE INDEX idx_rkt_task_type ON rag_knowledge_task(task_type);
|
||||
CREATE INDEX idx_rkt_status ON rag_knowledge_task(status);
|
||||
CREATE INDEX idx_rkt_deleted_at ON rag_knowledge_task(deleted_at);
|
||||
|
||||
-- 表和字段注释
|
||||
COMMENT ON TABLE rag_knowledge_task IS '知识库任务表';
|
||||
COMMENT ON COLUMN rag_knowledge_task.id IS '主键ID(非自增)';
|
||||
COMMENT ON COLUMN rag_knowledge_task.tenant_id IS '租户ID';
|
||||
COMMENT ON COLUMN rag_knowledge_task.creator IS '创建人';
|
||||
COMMENT ON COLUMN rag_knowledge_task.created_at IS '创建时间';
|
||||
COMMENT ON COLUMN rag_knowledge_task.updater IS '更新人';
|
||||
COMMENT ON COLUMN rag_knowledge_task.updated_at IS '更新时间';
|
||||
COMMENT ON COLUMN rag_knowledge_task.deleted_at IS '删除时间(软删)';
|
||||
COMMENT ON COLUMN rag_knowledge_task.task_id IS '任务ID';
|
||||
COMMENT ON COLUMN rag_knowledge_task.task_type IS '任务类型';
|
||||
COMMENT ON COLUMN rag_knowledge_task.status IS '任务状态';
|
||||
COMMENT ON COLUMN rag_knowledge_task.executor IS '执行器';
|
||||
COMMENT ON COLUMN rag_knowledge_task.remark IS '备注';
|
||||
|
||||
--------------------pgsql创建rag_knowledge_task表语句---------------------------
|
||||
|
||||
|
||||
--------------------pgsql创建rag_vector_dataset_index表语句---------------------------
|
||||
-- 向量数据集索引表
|
||||
CREATE TABLE IF NOT EXISTS rag_vector_dataset_index (
|
||||
-- 基础字段
|
||||
id BIGINT PRIMARY KEY, -- 主键ID(非自增)
|
||||
tenant_id BIGINT NOT NULL DEFAULT 0, -- 租户ID int8
|
||||
creator VARCHAR(64) NOT NULL,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updater VARCHAR(64) NOT NULL,
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
deleted_at timestamp(6),
|
||||
|
||||
-- 核心字段
|
||||
dataset_id INT8 NOT NULL,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
collection VARCHAR(255) NOT NULL,
|
||||
dimension INT NOT NULL,
|
||||
field_type VARCHAR(50) NOT NULL,
|
||||
metric_type VARCHAR(50) NOT NULL,
|
||||
status SMALLINT NOT NULL DEFAULT 1, -- 状态:1启用/0停用
|
||||
vector_count INT8 NOT NULL DEFAULT 0,
|
||||
description TEXT
|
||||
);
|
||||
|
||||
-- 唯一约束
|
||||
ALTER TABLE rag_vector_dataset_index ADD CONSTRAINT uk_dataset_id_name UNIQUE (dataset_id, name);
|
||||
|
||||
-- 索引
|
||||
CREATE INDEX idx_dataset_index_tenant_id ON rag_vector_dataset_index(tenant_id);
|
||||
CREATE INDEX idx_dataset_index_dataset_id ON rag_vector_dataset_index(dataset_id);
|
||||
CREATE INDEX idx_dataset_index_status ON rag_vector_dataset_index(status);
|
||||
|
||||
-- 注释
|
||||
COMMENT ON TABLE rag_vector_dataset_index IS '向量数据集索引表';
|
||||
COMMENT ON COLUMN rag_vector_dataset_index.id IS '主键ID(非自增)';
|
||||
COMMENT ON COLUMN rag_vector_dataset_index.tenant_id IS '租户ID';
|
||||
COMMENT ON COLUMN rag_vector_dataset_index.creator IS '创建人';
|
||||
COMMENT ON COLUMN rag_vector_dataset_index.created_at IS '创建时间';
|
||||
COMMENT ON COLUMN rag_vector_dataset_index.updater IS '更新人';
|
||||
COMMENT ON COLUMN rag_vector_dataset_index.updated_at IS '更新时间';
|
||||
COMMENT ON COLUMN rag_vector_dataset_index.deleted_at IS '删除时间(软删)';
|
||||
COMMENT ON COLUMN rag_vector_dataset_index.dataset_id IS '数据集ID';
|
||||
COMMENT ON COLUMN rag_vector_dataset_index.name IS '索引名称';
|
||||
COMMENT ON COLUMN rag_vector_dataset_index.collection IS '向量集合名称';
|
||||
COMMENT ON COLUMN rag_vector_dataset_index.dimension IS '向量维度';
|
||||
COMMENT ON COLUMN rag_vector_dataset_index.field_type IS '字段类型';
|
||||
COMMENT ON COLUMN rag_vector_dataset_index.metric_type IS '度量类型';
|
||||
COMMENT ON COLUMN rag_vector_dataset_index.status IS '状态';
|
||||
COMMENT ON COLUMN rag_vector_dataset_index.vector_count IS '向量数量';
|
||||
COMMENT ON COLUMN rag_vector_dataset_index.description IS '描述';
|
||||
|
||||
--------------------pgsql创建rag_vector_dataset_index表语句---------------------------
|
||||
|
||||
--------------------pgsql创建rag_vector_document_chunk表语句---------------------------
|
||||
|
||||
CREATE EXTENSION IF NOT EXISTS vector;
|
||||
|
||||
-- 文档分块向量表
|
||||
CREATE TABLE IF NOT EXISTS rag_vector_document_chunk (
|
||||
-- 基础字段
|
||||
id BIGINT PRIMARY KEY, -- 主键ID(非自增)
|
||||
tenant_id BIGINT NOT NULL DEFAULT 0, -- 租户ID int8
|
||||
creator VARCHAR(64) NOT NULL,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updater VARCHAR(64) NOT NULL,
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
deleted_at timestamp(6),
|
||||
|
||||
-- 核心字段
|
||||
status SMALLINT NOT NULL DEFAULT 1, -- 状态:1启用/0停用
|
||||
vector_status SMALLINT NOT NULL DEFAULT 1, -- 向量化状态: 1pending, 2processing, 3completed, 4failed,5partCompleted
|
||||
dataset_id INT8 NOT NULL,
|
||||
document_id INT8 NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
content_hash VARCHAR(128) NOT NULL,
|
||||
chunk_index INT8 NOT NULL,
|
||||
|
||||
-- 向量字段(pgvector)
|
||||
vector vector(1024) NOT NULL,
|
||||
|
||||
-- 扩展信息
|
||||
metadata JSONB
|
||||
);
|
||||
|
||||
-- 索引
|
||||
CREATE INDEX idx_chunk_tenant_id ON rag_vector_document_chunk(tenant_id);
|
||||
CREATE INDEX idx_chunk_dataset_id ON rag_vector_document_chunk(dataset_id);
|
||||
CREATE INDEX idx_chunk_document_id ON rag_vector_document_chunk(document_id);
|
||||
CREATE INDEX idx_chunk_content_hash ON rag_vector_document_chunk(content_hash);
|
||||
CREATE INDEX idx_chunk_status ON rag_vector_document_chunk(status);
|
||||
CREATE INDEX idx_chunk_vector_status ON rag_vector_document_chunk(vector_status);
|
||||
|
||||
-- 注释
|
||||
COMMENT ON TABLE rag_vector_document_chunk IS '文档分块向量表';
|
||||
COMMENT ON COLUMN rag_vector_document_chunk.id IS '主键ID(非自增)';
|
||||
COMMENT ON COLUMN rag_vector_document_chunk.tenant_id IS '租户ID';
|
||||
COMMENT ON COLUMN rag_vector_document_chunk.creator IS '创建人';
|
||||
COMMENT ON COLUMN rag_vector_document_chunk.created_at IS '创建时间';
|
||||
COMMENT ON COLUMN rag_vector_document_chunk.updater IS '更新人';
|
||||
COMMENT ON COLUMN rag_vector_document_chunk.updated_at IS '更新时间';
|
||||
COMMENT ON COLUMN rag_vector_document_chunk.deleted_at IS '删除时间(软删)';
|
||||
COMMENT ON COLUMN rag_vector_document_chunk.status IS '状态';
|
||||
COMMENT ON COLUMN rag_vector_document_chunk.vector_status IS '向量生成状态';
|
||||
COMMENT ON COLUMN rag_vector_document_chunk.dataset_id IS '数据集ID';
|
||||
COMMENT ON COLUMN rag_vector_document_chunk.document_id IS '文档ID';
|
||||
COMMENT ON COLUMN rag_vector_document_chunk.content IS '分块内容';
|
||||
COMMENT ON COLUMN rag_vector_document_chunk.content_hash IS '内容哈希';
|
||||
COMMENT ON COLUMN rag_vector_document_chunk.chunk_index IS '分块序号';
|
||||
COMMENT ON COLUMN rag_vector_document_chunk.vector IS '向量数据';
|
||||
COMMENT ON COLUMN rag_vector_document_chunk.metadata IS '扩展元数据';
|
||||
|
||||
--------------------pgsql创建rag_vector_document_chunk表语句---------------------------
|
||||
Reference in New Issue
Block a user