Compare commits

..

16 Commits

Author SHA1 Message Date
eea10519a3 ci/cd调整 2026-06-10 16:36:46 +08:00
qhd
61fbd50e3d feat: 添加文档处理API和配置更新 2026-04-22 17:58:07 +08:00
92e9d6b4ff gmq版本 2026-04-21 11:51:20 +08:00
42afbf878c gmq版本 2026-04-21 09:16:54 +08:00
qhd
27b1dd3c27 feat: 支持多租户多模型对话及文档去重优化 2026-04-16 15:47:37 +08:00
qhd
4ead3f82cf chore: 更新依赖并清理未使用导入 2026-04-11 18:38:08 +08:00
qhd
a05cac7591 feat: 新增关键词类型及优化查询逻辑
支持关键词类型区分,优化文件向量查询SQL及DAO更新逻辑,移除冗余配置和注释代码。
2026-04-11 18:24:37 +08:00
qhd
94df015aa9 refactor: 重构文档向量相关代码结构 2026-04-10 13:12:19 +08:00
a7b8713e26 增加dockerfile配置 2026-04-09 17:42:57 +08:00
b2f7cff277 增加dockerfile配置 2026-04-09 16:14:32 +08:00
93aef365e7 Merge remote-tracking branch 'origin/dev' into dev
# Conflicts:
#	service/document.go
2026-04-09 15:58:27 +08:00
cfcf705503 增加dockerfile配置 2026-04-09 15:58:05 +08:00
qhd
2ced0a43e5 feat: 优化RAG检索与聊天模型支持历史对话
实现双路检索并行优化,使用EINO官方模板重构聊天逻辑,增加多轮对话历史记录管理及相关性过滤,并修复数据库唯一索引。
2026-04-09 13:57:46 +08:00
qhd
14a429f4ae chore: 移除 Elasticsearch 依赖 2026-04-09 09:14:20 +08:00
qhd
ff5fc54b35 Merge branch 'dev' of http://116.204.74.41:3000/red-future/rag into dev
# Conflicts:
#	go.mod
#	go.sum
2026-04-09 09:13:11 +08:00
qhd
7f894745e9 refactor: 重构文档处理流程和任务管理 2026-04-09 09:11:43 +08:00
65 changed files with 3908 additions and 2353 deletions

View File

@@ -1,24 +1,24 @@
# 最小化Docker镜像
FROM busybox:uclibc
# 阶段1: 构建
FROM golang:alpine AS builder
WORKDIR /app
RUN apk add --no-cache git ca-certificates tzdata
# 复制时区数据
COPY timezone/localtime /etc/localtime
COPY timezone/timezone /etc/timezone
COPY timezone/Shanghai /usr/share/zoneinfo/Asia/Shanghai
ENV TZ=Asia/Shanghai
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
# 复制预构建的二进制文件和配置文件
COPY rag_binary ./main
COPY config.yml ./
ENV GO111MODULE=on
ENV GOPROXY=https://goproxy.cn,direct
ENV CGO_ENABLED=0
ENV GOTOOLCHAIN=auto
WORKDIR /build
# 创建日志目录
RUN mkdir -p /logs /app/resource/log/run /app/resource/log/server
COPY . .
# 添加执行权限
RUN chmod +x /app/main
RUN go mod download && go mod tidy
EXPOSE 3008
RUN go build -ldflags="-s -w" -o main ./main.go
EXPOSE 3006
# 使用root用户运行
CMD ["./main"]

View File

@@ -1,166 +0,0 @@
package eino
import (
"context"
"errors"
"fmt"
"io"
"log"
"os"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/components/retriever"
"github.com/cloudwego/eino/schema"
"github.com/cloudwego/eino-ext/components/model/ark"
)
// main walks through one RAG customer-service turn end to end: it builds the
// retriever, the prompt template and the ARK chat model, retrieves reference
// knowledge for the current question, formats the prompt together with the
// prior conversation, streams the model answer to stdout and finally appends
// the new turn to the in-memory session history.
//
// Any failure terminates the process through log.Fatal.
func main() {
	ctx := context.Background()

	// 1. Initialise the three components.
	// 1.1 Vector retriever (looks up customer-service knowledge).
	ragRetriever := NewPGVectorRetriever()
	// 1.2 Prompt template (role + history + knowledge + user question).
	chatTpl := newCustomerServiceTemplate()
	// 1.3 ARK chat model, configured from the environment.
	chatModel, err := ark.NewChatModel(ctx, &ark.ChatModelConfig{
		APIKey: os.Getenv("ARK_API_KEY"),
		Model:  os.Getenv("ARK_MODEL_ID"),
	})
	if err != nil {
		log.Fatal(err)
	}

	// 2. Simulated session: in a real service this history comes from a DB.
	sessionHistory := []*schema.Message{
		{Role: schema.User, Content: "你们发什么快递?"},
		{Role: schema.Assistant, Content: "默认发中通快递"},
		{Role: schema.User, Content: "可以发顺丰吗?"},
	}
	// The question asked in the current turn.
	userQuery := "那顺丰需要加钱吗?"

	// 3. RAG: retrieve reference knowledge for the question.
	docs, err := ragRetriever.Retrieve(ctx, userQuery)
	if err != nil {
		log.Fatal(err)
	}
	// Concatenate the retrieved documents into one numbered reference text.
	knowledge := ""
	for i, doc := range docs {
		knowledge += fmt.Sprintf("[参考%d] %s\n", i+1, doc.Content)
	}

	// 4. Render the template: system prompt + history + knowledge + question.
	msgs, err := chatTpl.Format(ctx, map[string]any{
		"history":   sessionHistory,
		"knowledge": knowledge,
		"question":  userQuery,
	})
	if err != nil {
		log.Fatal(err)
	}

	// 5. Stream the model answer and echo every chunk as it arrives.
	fmt.Println("\n=== 客服回复 ===")
	stream, err := chatModel.Stream(ctx, msgs)
	if err != nil {
		log.Fatal(err)
	}
	fullReply := make([]*schema.Message, 0, 100)
	for {
		chunk, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			// log.Fatal exits without running defers, so release the stream first.
			stream.Close()
			log.Fatal(err)
		}
		fmt.Print(chunk.Content)
		fullReply = append(fullReply, chunk)
	}
	stream.Close()

	// 6. Merge the chunks into one message and record the turn as new history.
	replyMsg, err := schema.ConcatMessages(fullReply)
	if err != nil {
		log.Fatal(err)
	}
	sessionHistory = append(sessionHistory,
		&schema.Message{Role: schema.User, Content: userQuery},
		replyMsg,
	)
	// Next step: persist sessionHistory back to MySQL/Redis.
}
// ==========================================
// 本地客服提示词模板(不需要 MCP
// ==========================================
func newCustomerServiceTemplate() prompt.ChatTemplate {
// 系统提示 + 多轮对话 + 知识库 + 用户问题
return prompt.FromMessages(schema.Messages{
{
Role: schema.System,
Content: `你是电商智能客服,语气友好简洁。
请严格根据参考知识回答,不知道就说“抱歉,这个问题我需要帮你转接人工”。
参考知识:
{{.knowledge}}`,
},
// 历史对话会自动渲染在这里
{{range .history}}{{.}},{{end}},
// 当前用户问题
{Role: schema.User, Content: "{{.question}}"},
})
}
// ==========================================
// PGVector 检索器(简化可直接用)
// ==========================================
// PGVectorRetriever is a simplified stand-in for a pgvector-backed retriever.
// It currently serves a fixed, in-memory knowledge base.
type PGVectorRetriever struct {
	// topK is the default maximum number of documents returned by Retrieve.
	topK int
}

// NewPGVectorRetriever returns a retriever with a default topK of 3.
func NewPGVectorRetriever() retriever.Retriever {
	return &PGVectorRetriever{topK: 3}
}

// Retrieve returns at most topK documents for query. The limit is the
// retriever's default unless the caller overrides it with a TopK option;
// a non-positive limit means "no limit".
//
// The query itself is not used yet: the result is a mock knowledge base that
// should be replaced by a real PG vector-search SQL query.
func (r *PGVectorRetriever) Retrieve(
	ctx context.Context,
	query string,
	opts ...retriever.Option,
) ([]*schema.Document, error) {
	options := retriever.GetCommonOptions(nil, opts...)
	topK := r.topK
	if options.TopK != nil {
		topK = *options.TopK
	}

	// ===== Replace this with the real PG vector-search SQL =====
	// Mock knowledge base.
	docs := []*schema.Document{
		{
			ID:      "1",
			Content: "顺丰快递需要补10元运费差价",
		},
		{
			ID:      "2",
			Content: "订单满99元可免费升级顺丰",
		},
	}
	// Honour the requested limit; previously topK was computed but never applied.
	if topK > 0 && topK < len(docs) {
		docs = docs[:topK]
	}
	return docs, nil
}

View File

@@ -1,107 +0,0 @@
package eino
import (
"context"
"fmt"
"github.com/cloudwego/eino/schema"
"github.com/elastic/go-elasticsearch/v8"
"github.com/cloudwego/eino-ext/components/indexer/es8"
)
const (
indexName = "eino_example"
fieldContent = "content"
fieldContentVector = "content_vector"
fieldExtraLocation = "location"
docExtraLocation = "location"
)
// TestIndexer is a manual smoke test for the ES8 indexer component: it connects
// to a local Elasticsearch at http://localhost:9200, stores two sample documents
// in the indexName index (embedding their content via EmbedderDashscope) and
// prints the resulting IDs. Every failure is printed and ends the function early.
func TestIndexer() {
ctx := context.Background()
// 1. Create the ES client.
client, err := elasticsearch.NewClient(elasticsearch.Config{
Addresses: []string{"http://localhost:9200"},
})
if err != nil {
fmt.Printf("create client error: %v\n", err)
return
}
// 2. Define the index spec (optional; presumably lets the indexer create the
// index when it does not exist yet — confirm against the es8 indexer docs).
indexSpec := &es8.IndexSpec{
Settings: map[string]any{
"number_of_shards": 1,
"number_of_replicas": 0,
},
Mappings: map[string]any{
"properties": map[string]any{
fieldContentVector: map[string]any{
"type": "dense_vector",
"dims": 1024,
"index": true,
"similarity": "l2_norm",
},
},
},
}
// 4. Prepare the documents (step numbering skips 3 in the original).
// A document normally carries an ID and Content;
// extra MetaData can be attached for filtering or other uses.
docs := []*schema.Document{
{
ID: "1",
Content: "Eiffel Tower: Located in Paris, France.",
MetaData: map[string]any{
docExtraLocation: "France",
},
},
{
ID: "2",
Content: "The Great Wall: Located in China.",
MetaData: map[string]any{
docExtraLocation: "China",
},
},
}
// 5. Create the ES indexer component.
indexer, err := es8.NewIndexer(ctx, &es8.IndexerConfig{
Client: client,
Index: indexName,
IndexSpec: indexSpec, // set this to enable automatic index creation
BatchSize: 10,
// DocumentToFields maps document fields to ES fields.
DocumentToFields: func(ctx context.Context, doc *schema.Document) (field2Value map[string]es8.FieldValue, err error) {
return map[string]es8.FieldValue{
fieldContent: {
Value: doc.Content,
EmbedKey: fieldContentVector, // embed the content and store the vector in the "content_vector" field
},
fieldExtraLocation: {
// Additional metadata field.
Value: doc.MetaData[docExtraLocation],
},
}, nil
},
// Embedding component used to vectorise the content.
Embedding: EmbedderDashscope,
})
if err != nil {
fmt.Printf("create indexer error: %v\n", err)
return
}
// 6. Index the documents.
ids, err := indexer.Store(ctx, docs)
if err != nil {
fmt.Printf("index error: %v\n", err)
return
}
fmt.Println("indexed ids:", ids)
}

View File

@@ -1,49 +0,0 @@
package eino
import (
"time"
"gitea.com/red-future/common/beans"
)
// BaseTask is the common task record for MongoDB storage. It embeds
// beans.MongoBaseDO inline and adds type/status, progress and result fields.
type BaseTask struct {
beans.MongoBaseDO `bson:",inline"`
// Task information.
TaskType TaskType `bson:"taskType" json:"taskType"`
Status TaskStatus `bson:"status" json:"status"`
Priority TaskPriority `bson:"priority,omitempty" json:"priority,omitempty"`
// Progress.
TotalItems int64 `bson:"totalItems" json:"totalItems"`
ProcessedItems int64 `bson:"processedItems" json:"processedItems"`
Progress float64 `bson:"progress" json:"progress"`
// Result.
StartTime *time.Time `bson:"startTime" json:"startTime"`
EndTime *time.Time `bson:"endTime,omitempty" json:"endTime,omitempty"`
// NOTE(review): the unit of Duration is not visible here — confirm (seconds vs. milliseconds).
Duration int64 `bson:"duration,omitempty" json:"duration,omitempty"`
SuccessCount int64 `bson:"successCount" json:"successCount"`
FailCount int64 `bson:"failCount" json:"failCount"`
// Miscellaneous.
Executor string `bson:"executor,omitempty" json:"executor,omitempty"`
}
// SQLBaseTask is the common task record for SQL storage. It embeds
// beans.SQLBaseDO and mirrors the fields of BaseTask with JSON tags only.
type SQLBaseTask struct {
beans.SQLBaseDO
// Task information.
TaskType TaskType `json:"taskType"`
Status TaskStatus `json:"status"`
Priority TaskPriority `json:"priority,omitempty"`
// Progress.
TotalItems int64 `json:"totalItems"`
ProcessedItems int64 `json:"processedItems"`
Progress float64 `json:"progress"`
// Result.
StartTime *time.Time `json:"startTime"`
EndTime *time.Time `json:"endTime,omitempty"`
// NOTE(review): the unit of Duration is not visible here — confirm (seconds vs. milliseconds).
Duration int64 `json:"duration,omitempty"`
SuccessCount int64 `json:"successCount"`
FailCount int64 `json:"failCount"`
// Miscellaneous.
Executor string `json:"executor,omitempty"`
}

View File

@@ -1,94 +0,0 @@
package eino
import (
"context"
"encoding/json"
"fmt"
"github.com/cloudwego/eino/schema"
"github.com/elastic/go-elasticsearch/v8"
"github.com/elastic/go-elasticsearch/v8/typedapi/types"
"github.com/cloudwego/eino-ext/components/retriever/es8"
"github.com/cloudwego/eino-ext/components/retriever/es8/search_mode"
)
// TestRetriever is a manual smoke test for the ES8 retriever component: it
// connects to a local Elasticsearch at http://localhost:9200, runs one
// approximate (vector) search without filters and one restricted to documents
// whose location is "China", and prints the documents of the filtered search.
//
// Like TestIndexer, every failure is printed and ends the function early
// instead of being silently discarded.
func TestRetriever() {
	ctx := context.Background()

	client, err := elasticsearch.NewClient(elasticsearch.Config{
		Addresses: []string{"http://localhost:9200"},
	})
	if err != nil {
		fmt.Printf("create client error: %v\n", err)
		return
	}

	// Create the retriever component.
	retriever, err := es8.NewRetriever(ctx, &es8.RetrieverConfig{
		Client: client,
		Index:  indexName,
		TopK:   5,
		SearchMode: search_mode.SearchModeApproximate(&search_mode.ApproximateConfig{
			QueryFieldName:  fieldContent,
			VectorFieldName: fieldContentVector,
			Hybrid:          false,
			// RRF is only available under specific licences.
			// See: https://www.elastic.co/subscriptions
			RRF:             false,
			RRFRankConstant: nil,
			RRFWindowSize:   nil,
		}),
		// ResultParser converts one ES hit into a schema.Document. Unexpected
		// field types are reported as errors rather than panicking.
		ResultParser: func(ctx context.Context, hit types.Hit) (doc *schema.Document, err error) {
			if hit.Id_ == nil {
				return nil, fmt.Errorf("hit has no id")
			}
			doc = &schema.Document{
				ID:       *hit.Id_,
				Content:  "",
				MetaData: map[string]any{},
			}

			var src map[string]any
			if err = json.Unmarshal(hit.Source_, &src); err != nil {
				return nil, err
			}

			for field, val := range src {
				switch field {
				case fieldContent:
					content, ok := val.(string)
					if !ok {
						return nil, fmt.Errorf("field %q: expected string, got %T", field, val)
					}
					doc.Content = content
				case fieldContentVector:
					items, ok := val.([]any)
					if !ok {
						return nil, fmt.Errorf("field %q: expected array, got %T", field, val)
					}
					vec := make([]float64, 0, len(items))
					for _, item := range items {
						f, ok := item.(float64)
						if !ok {
							return nil, fmt.Errorf("field %q: expected number, got %T", field, item)
						}
						vec = append(vec, f)
					}
					doc.WithDenseVector(vec)
				case fieldExtraLocation:
					location, ok := val.(string)
					if !ok {
						return nil, fmt.Errorf("field %q: expected string, got %T", field, val)
					}
					doc.MetaData[docExtraLocation] = location
				}
			}

			if hit.Score_ != nil {
				doc.WithScore(float64(*hit.Score_))
			}
			return doc, nil
		},
		Embedding: EmbedderDashscope,
	})
	if err != nil {
		fmt.Printf("create retriever error: %v\n", err)
		return
	}

	// Search without filters; only success matters, the result is not printed.
	if _, err = retriever.Retrieve(ctx, "tourist attraction"); err != nil {
		fmt.Printf("retrieve error: %v\n", err)
		return
	}

	// Search with a term filter on the location field.
	docs, err := retriever.Retrieve(ctx, "tourist attraction",
		es8.WithFilters([]types.Query{{
			Term: map[string]types.TermQuery{
				fieldExtraLocation: {
					CaseInsensitive: of(true),
					Value:           "China",
				},
			},
		}}),
	)
	if err != nil {
		fmt.Printf("retrieve with filter error: %v\n", err)
		return
	}

	fmt.Printf("retrieved docs: %+v\n", docs)
}
// of returns a pointer to a fresh copy of v. It is a convenience for filling
// optional pointer fields from literals, e.g. of(true).
func of[T any](v T) *T {
	p := new(T)
	*p = v
	return p
}

243
common/eino/chat.go Normal file
View File

@@ -0,0 +1,243 @@
package eino
import (
"context"
"fmt"
"rag/consts/model"
"rag/dao"
"rag/model/dto"
"rag/model/entity"
"gitea.com/red-future/common/jaeger"
"gitea.com/red-future/common/utils"
"github.com/cloudwego/eino-ext/components/model/ark"
"github.com/cloudwego/eino-ext/components/model/arkbot"
"github.com/cloudwego/eino-ext/components/model/claude"
"github.com/cloudwego/eino-ext/components/model/deepseek"
"github.com/cloudwego/eino-ext/components/model/ollama"
"github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino-ext/components/model/qianfan"
"github.com/cloudwego/eino-ext/components/model/qwen"
modelChat "github.com/cloudwego/eino/components/model"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// ChatModelSet holds one tenant's chat model clients, one optional slot per
// supported provider. A nil field means that provider has not been initialised
// for the tenant.
type ChatModelSet struct {
Ark *ark.ChatModel
ArkBot *arkbot.ChatModel
Claude *claude.ChatModel
DeepSeek *deepseek.ChatModel
Ollama *ollama.ChatModel
OpenAI *openai.ChatModel
Qianfan *qianfan.ChatModel
Qwen *qwen.ChatModel
}
// tenantChatModels is the package-wide tenant registry:
// key = tenant ID, value = that tenant's chat models.
// NOTE(review): the map is written by InitChat/RefreshTenantChatModel and read
// by GetTenantChatModel without a lock — confirm these never run concurrently.
var tenantChatModels = make(map[uint64]*ChatModelSet)
// init loads the chat models of all tenants at package load time, inside an
// "InitAllChat" tracing span.
func init() {
	spanCtx, span := jaeger.NewSpan(context.Background(), "InitAllChat")
	defer span.End()

	InitAllChat(spanCtx)
}
// ===================== 1. 服务启动时:初始化所有租户对话模型 =====================
func InitAllChat(ctx context.Context) {
list, err := dao.Model.GetNoTenantId(ctx, &dto.GetModelReq{
ModelType: model.ModelTypeChat.Code(),
})
if err != nil {
g.Log().Errorf(ctx, "获取所有租户对话模型失败: %v", err)
return
}
for _, l := range list {
err = InitChat(ctx, l)
if err != nil {
g.Log().Errorf(ctx, "初始化租户[%v]的对话模型失败: %v", l.TenantId, err)
continue
}
}
}
// InitChat builds the chat model client described by modelDO and registers it
// for modelDO.TenantId.
//
// The provider is selected by modelDO.ConfigType; modelDO.ConfigContent is
// decoded into the matching provider config struct. Sampling parameters are
// fixed (temperature 0.7, max tokens 1024, top-p 1.0) where the provider
// config accepts them.
//
// The new client is merged into a copy of the tenant's existing ChatModelSet,
// so initialising a second provider for the same tenant no longer discards
// the first one. The tenant's entry is replaced only after the client was
// created successfully.
//
// It returns an error when modelDO or its ConfigType is nil, the config cannot
// be decoded, the provider type is unsupported, or the client cannot be created.
func InitChat(ctx context.Context, modelDO *entity.Model) (err error) {
	if modelDO == nil || modelDO.ConfigType == nil {
		return fmt.Errorf("对话模型配置不能为空")
	}

	// Copy-on-write: start from the tenant's current set so other providers survive.
	set := &ChatModelSet{}
	if prev := tenantChatModels[modelDO.TenantId]; prev != nil {
		*set = *prev
	}

	switch *modelDO.ConfigType {
	case *model.ModelConfigTypeChatArk.Code():
		var cfg entity.ChatModelConfigArk
		if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
			return fmt.Errorf("解析Ark配置失败: %w", err)
		}
		set.Ark, err = ark.NewChatModel(ctx, &ark.ChatModelConfig{
			APIKey:      cfg.APIKey,
			Model:       cfg.Model,
			Temperature: gconv.PtrFloat32(0.7),
			MaxTokens:   gconv.PtrInt(1024),
			TopP:        gconv.PtrFloat32(1.0),
		})

	case *model.ModelConfigTypeChatArkBot.Code():
		var cfg entity.ChatModelConfigArkBot
		if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
			return fmt.Errorf("解析ArkBot配置失败: %w", err)
		}
		set.ArkBot, err = arkbot.NewChatModel(ctx, &arkbot.Config{
			APIKey:      cfg.APIKey,
			Model:       cfg.Model,
			Temperature: gconv.PtrFloat32(0.7),
			MaxTokens:   gconv.PtrInt(1024),
			TopP:        gconv.PtrFloat32(1.0),
		})

	case *model.ModelConfigTypeChatClaude.Code():
		var cfg entity.ChatModelConfigClaude
		if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
			return fmt.Errorf("解析Claude配置失败: %w", err)
		}
		claudeCfg := claude.Config{
			APIKey:          cfg.APIKey,
			BaseURL:         gconv.PtrString(cfg.BaseURL),
			Model:           cfg.Model,
			Temperature:     gconv.PtrFloat32(0.7),
			MaxTokens:       gconv.Int(1024),
			TopP:            gconv.PtrFloat32(1.0),
			ByBedrock:       cfg.ByBedrock,
			AccessKey:       cfg.AccessKey,
			SecretAccessKey: cfg.SecretAccessKey,
			Region:          cfg.Region,
		}
		set.Claude, err = claude.NewChatModel(ctx, &claudeCfg)

	case *model.ModelConfigTypeChatDeepSeek.Code():
		var cfg entity.ChatModelConfigDeepSeek
		if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
			return fmt.Errorf("解析DeepSeek配置失败: %w", err)
		}
		set.DeepSeek, err = deepseek.NewChatModel(ctx, &deepseek.ChatModelConfig{
			APIKey:      cfg.APIKey,
			Model:       cfg.Model,
			BaseURL:     cfg.BaseURL,
			Temperature: gconv.Float32(0.7),
			MaxTokens:   gconv.Int(1024),
			TopP:        gconv.Float32(1.0),
		})

	case *model.ModelConfigTypeChatOllama.Code():
		var cfg entity.ChatModelConfigOllama
		if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
			return fmt.Errorf("解析Ollama配置失败: %w", err)
		}
		set.Ollama, err = ollama.NewChatModel(ctx, &ollama.ChatModelConfig{
			BaseURL: cfg.BaseURL,
			Model:   cfg.Model,
		})

	case *model.ModelConfigTypeChatOpenAI.Code():
		var cfg entity.ChatModelConfigOpenAI
		if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
			return fmt.Errorf("解析OpenAI配置失败: %w", err)
		}
		openAiCfg := openai.ChatModelConfig{
			APIKey:              cfg.APIKey,
			Model:               cfg.Model,
			ByAzure:             cfg.ByAzure,
			BaseURL:             cfg.BaseURL,
			APIVersion:          cfg.APIVersion,
			Temperature:         gconv.PtrFloat32(0.7),
			MaxCompletionTokens: gconv.PtrInt(1024),
			TopP:                gconv.PtrFloat32(1.0),
		}
		set.OpenAI, err = openai.NewChatModel(ctx, &openAiCfg)

	case *model.ModelConfigTypeChatQianfan.Code():
		var cfg entity.ChatModelConfigQianfan
		if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
			return fmt.Errorf("解析千帆配置失败: %w", err)
		}
		// NOTE(review): credentials go into the config returned by
		// GetQianfanSingletonConfig; the name suggests it is shared process-wide,
		// so different tenants' keys may overwrite each other — confirm.
		qcfg := qianfan.GetQianfanSingletonConfig()
		qcfg.AccessKey = cfg.AccessKey
		qcfg.SecretKey = cfg.SecretKey
		set.Qianfan, err = qianfan.NewChatModel(ctx, &qianfan.ChatModelConfig{
			Model:               cfg.Model,
			Temperature:         gconv.PtrFloat32(0.7),
			MaxCompletionTokens: gconv.PtrInt(1024),
			TopP:                gconv.PtrFloat32(1.0),
		})

	case *model.ModelConfigTypeChatQwen.Code():
		var cfg entity.ChatModelConfigQwen
		if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
			return fmt.Errorf("解析Qwen配置失败: %w", err)
		}
		set.Qwen, err = qwen.NewChatModel(ctx, &qwen.ChatModelConfig{
			APIKey:      cfg.APIKey,
			Model:       cfg.Model,
			BaseURL:     cfg.BaseURL,
			Temperature: gconv.PtrFloat32(0.7),
			MaxTokens:   gconv.PtrInt(1024),
			TopP:        gconv.PtrFloat32(1.0),
		})

	default:
		return fmt.Errorf("不支持的对话模型类型: %v", *modelDO.ConfigType)
	}

	if err != nil {
		return fmt.Errorf("初始化对话模型失败: %w", err)
	}

	// Stored without locking, as before.
	tenantChatModels[modelDO.TenantId] = set
	g.Log().Infof(ctx, "租户[%v]对话模型[%v]初始化成功", modelDO.TenantId, *modelDO.ConfigType)
	return nil
}
// GetTenantChatModel returns the chat model set registered for tenantId, or an
// error when the tenant has no (non-nil) entry in tenantChatModels.
func GetTenantChatModel(tenantId uint64) (*ChatModelSet, error) {
	if set := tenantChatModels[tenantId]; set != nil {
		return set, nil
	}
	return nil, fmt.Errorf("租户[%v]对话模型未初始化", tenantId)
}
// GetTenantChatModelByType resolves the chat model of the given provider type
// for the tenant of the user carried in ctx.
//
// It returns an error when configType is nil, the user info cannot be read
// from ctx, the tenant has no registered models, the provider type is
// unsupported, or the tenant has not initialised that provider. The last check
// matters: returning a nil concrete pointer through the interface would give
// the caller a non-nil interface value that panics on first use.
func GetTenantChatModelByType(ctx context.Context, configType model.ModelConfigType) (modelChat.BaseChatModel, error) {
	if configType == nil {
		return nil, fmt.Errorf("对话模型类型不能为空")
	}

	userInfo, err := utils.GetUserInfo(ctx)
	if err != nil {
		return nil, err
	}
	set, err := GetTenantChatModel(userInfo.TenantId)
	if err != nil {
		return nil, err
	}

	var (
		chat  modelChat.BaseChatModel
		ready bool // true only when the concrete pointer behind chat is non-nil
	)
	switch *configType {
	case *model.ModelConfigTypeChatArk.Code():
		chat, ready = set.Ark, set.Ark != nil
	case *model.ModelConfigTypeChatArkBot.Code():
		chat, ready = set.ArkBot, set.ArkBot != nil
	case *model.ModelConfigTypeChatClaude.Code():
		chat, ready = set.Claude, set.Claude != nil
	case *model.ModelConfigTypeChatDeepSeek.Code():
		chat, ready = set.DeepSeek, set.DeepSeek != nil
	case *model.ModelConfigTypeChatOllama.Code():
		chat, ready = set.Ollama, set.Ollama != nil
	case *model.ModelConfigTypeChatOpenAI.Code():
		chat, ready = set.OpenAI, set.OpenAI != nil
	case *model.ModelConfigTypeChatQianfan.Code():
		chat, ready = set.Qianfan, set.Qianfan != nil
	case *model.ModelConfigTypeChatQwen.Code():
		chat, ready = set.Qwen, set.Qwen != nil
	default:
		return nil, fmt.Errorf("不支持的对话模型类型: %v", *configType)
	}

	if !ready {
		return nil, fmt.Errorf("租户[%v]的对话模型[%v]未初始化", userInfo.TenantId, *configType)
	}
	return chat, nil
}
// RefreshTenantChatModel drops the tenant's registered chat models and then
// re-initialises the one described by modelDO, returning InitChat's error.
// If InitChat fails the tenant is left without an entry.
func RefreshTenantChatModel(ctx context.Context, modelDO *entity.Model) error {
	tenantID := modelDO.TenantId
	delete(tenantChatModels, tenantID)

	return InitChat(ctx, modelDO)
}

125
common/eino/chat_model.go Normal file
View File

@@ -0,0 +1,125 @@
package eino
import (
"context"
"errors"
"fmt"
"io"
"rag/consts/model"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/schema"
)
const (
// MaxHistoryTurns is the maximum number of conversation turns kept as
// history; limitHistory keeps 2*MaxHistoryTurns messages.
MaxHistoryTurns = 5
)
var (
// ragPromptTemplate is the shared Eino chat template for RAG answers;
// it is assigned once by initRAGPromptTemplate.
ragPromptTemplate prompt.ChatTemplate
)
// init prepares the package-level Eino prompt template before first use.
func init() {
	initRAGPromptTemplate()
}
// 初始化 EINO 官方提示词模板(最关键!)
func initRAGPromptTemplate() {
ragPromptTemplate = prompt.FromMessages(
schema.FString,
// 系统提示(带参考知识)
&schema.Message{
Role: schema.System,
Content: `你是专业客服,语气友好简洁。
请依据参考知识回答,不知道就说:抱歉,我暂时无法回答这个问题。
参考知识:
{knowledge}`,
},
// 用户问题
&schema.Message{
Role: schema.User,
Content: "{question}",
},
)
}
// NewChatModel answers question with the tenant's chat model of type chatModel.
// Despite its name it does not create a model: it renders the RAG prompt from
// docs and question, splices the trimmed conversation history between the
// system message and the user question, and returns the fully streamed reply.
func NewChatModel(ctx context.Context, question string, docs []*schema.Document, history []*schema.Message, chatModel model.ModelConfigType) (replyMsg *schema.Message, err error) {
	// 1. Build the reference knowledge text.
	knowledge := buildKnowledgeAndSources(docs)

	// 2. Keep only the most recent user/assistant messages.
	recent := limitHistory(history)

	// 3. Render the shared template.
	vars := map[string]any{
		"knowledge": knowledge,
		"question":  question,
	}
	msgs, err := ragPromptTemplate.Format(ctx, vars)
	if err != nil {
		return nil, err
	}

	// 4. Insert the history right after the first (system) message.
	if len(recent) > 0 {
		merged := make([]*schema.Message, 0, len(msgs)+len(recent))
		merged = append(merged, msgs[:1]...)
		merged = append(merged, recent...)
		merged = append(merged, msgs[1:]...)
		msgs = merged
	}

	// 5. Generate the answer with the already-initialised tenant model.
	return streamGenerateAnswer(ctx, msgs, chatModel)
}
// limitHistory keeps only user and assistant messages from history and returns
// at most the last 2*MaxHistoryTurns of them, in their original order.
func limitHistory(history []*schema.Message) []*schema.Message {
	dialog := make([]*schema.Message, 0, len(history))
	for _, msg := range history {
		switch msg.Role {
		case schema.User, schema.Assistant:
			dialog = append(dialog, msg)
		}
	}

	if excess := len(dialog) - 2*MaxHistoryTurns; excess > 0 {
		dialog = dialog[excess:]
	}
	return dialog
}
// buildKnowledgeAndSources concatenates the content of docs into the reference
// knowledge text, one "[参考N] <content>" line per document with N starting at
// 1. It returns the empty string for no documents. Only the knowledge text is
// produced; no separate source list is built.
func buildKnowledgeAndSources(docs []*schema.Document) string {
	// Append into one growing buffer instead of re-allocating the whole string
	// on every iteration.
	var buf []byte
	for i, doc := range docs {
		buf = fmt.Appendf(buf, "[参考%d] %s\n", i+1, doc.Content)
	}
	return string(buf)
}
// streamGenerateAnswer sends msgs to the current tenant's chat model of type
// chatModel in streaming mode, collects every chunk until io.EOF and returns
// the chunks merged into a single message.
//
// It returns an error when the model cannot be resolved, the stream cannot be
// opened, receiving a chunk fails, or the chunks cannot be concatenated.
func streamGenerateAnswer(ctx context.Context, msgs []*schema.Message, chatModel model.ModelConfigType) (reply *schema.Message, err error) {
	cm, err := GetTenantChatModelByType(ctx, chatModel)
	if err != nil {
		return nil, err
	}

	sr, err := cm.Stream(ctx, msgs)
	if err != nil {
		return nil, fmt.Errorf("stream failed: %w", err)
	}
	// Always release the stream, including on early error returns.
	defer sr.Close()

	var chunks []*schema.Message
	for {
		chunk, err := sr.Recv()
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			return nil, fmt.Errorf("stream recv failed: %w", err)
		}
		chunks = append(chunks, chunk)
	}

	return schema.ConcatMessages(chunks)
}

View File

@@ -1,8 +0,0 @@
package eino
const (
providerArk = "ark"
providerOpenai = "openai"
providerQianfan = "qianfan"
providerDashscope = "dashscope"
)

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"gitea.com/red-future/common/utils"
"github.com/cloudwego/eino-ext/components/document/loader/file"
"github.com/cloudwego/eino-ext/components/document/loader/url"
"github.com/cloudwego/eino-ext/components/document/parser/docx"
"github.com/cloudwego/eino-ext/components/document/parser/pdf"
@@ -15,7 +16,7 @@ import (
)
// LoadDocument 业务函数:加载文件
func LoadDocument(ctx context.Context, filePath, fileFormat string) (docs []*schema.Document, err error) {
func a(ctx context.Context, filePath, fileFormat string) (docs []*schema.Document, err error) {
p, err := docsParser(ctx, fileFormat)
if err != nil {
return
@@ -27,12 +28,34 @@ func LoadDocument(ctx context.Context, filePath, fileFormat string) (docs []*sch
if err != nil {
return
}
docs, err = loader.Load(context.Background(), document.Source{
docs, err = loader.Load(ctx, document.Source{
URI: fmt.Sprintf("%s%s", imageUrl, filePath),
})
return
}
// LoadDocument loads the local file at filePath and parses it with the parser
// selected by fileFormat (see docsParser).
//
// It returns the parsed documents, or an error when the format has no parser,
// the file loader cannot be created, or loading fails.
func LoadDocument(ctx context.Context, filePath, fileFormat string) (docs []*schema.Document, err error) {
	p, err := docsParser(ctx, fileFormat)
	if err != nil {
		return nil, err
	}

	// 1. Create the file loader.
	loader, err := file.NewFileLoader(ctx, &file.FileLoaderConfig{
		UseNameAsID: false, // do not use the file name as the document ID
		Parser:      p,
	})
	if err != nil {
		return nil, err
	}

	// 2. Load the requested file (previously a hard-coded developer path was
	// used and filePath was ignored).
	return loader.Load(ctx, document.Source{
		URI: filePath,
	})
}
func docsParser(ctx context.Context, fileFormat string) (p parser.Parser, err error) {
switch fileFormat {
case "docx":

View File

@@ -2,6 +2,7 @@ package eino
import (
"context"
"rag/consts/model"
"github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive"
"github.com/cloudwego/eino-ext/components/document/transformer/splitter/semantic"
@@ -10,7 +11,7 @@ import (
)
// SemanticSplitDocument 语义分割文档
func SemanticSplitDocument(ctx context.Context, docs []*schema.Document) (res []*schema.Document, err error) {
func SemanticSplitDocument(ctx context.Context, docs []*schema.Document, vectorModel model.ModelConfigType) (res []*schema.Document, err error) {
// 默认分隔符(支持中英文)
separators := []string{"\n\n", "\n", "。", "", "", "", ".", "!", "?", ";"}
// 读取配置,使用合理的默认值
@@ -18,24 +19,14 @@ func SemanticSplitDocument(ctx context.Context, docs []*schema.Document) (res []
minChunkSize := g.Cfg().MustGet(ctx, "eino.splitter.minChunkSize").Int()
percentile := g.Cfg().MustGet(ctx, "eino.splitter.percentile").Float64()
batchSize := g.Cfg().MustGet(ctx, "eino.splitter.batchSize").Int()
if batchSize <= 0 {
batchSize = 10 // doubao-embedding-vision 限制每批最多 10 个
}
// 使用批量包装器
var batchEmbedder *BatchEmbedder
provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String()
switch provider {
case providerArk:
batchEmbedder = NewBatchEmbedder(EmbedderArk, batchSize)
case providerOpenai:
batchEmbedder = NewBatchEmbedder(EmbedderOpenAI, batchSize)
case providerDashscope:
batchEmbedder = NewBatchEmbedder(EmbedderDashscope, batchSize)
embedder, err := GetTenantEmbedderByType(ctx, vectorModel)
if err != nil {
return nil, err
}
splitter, err := semantic.NewSplitter(ctx, &semantic.Config{
Embedding: batchEmbedder,
Embedding: NewBatchEmbedder(embedder, batchSize),
BufferSize: bufferSize,
MinChunkSize: minChunkSize,
Percentile: percentile,

View File

@@ -3,67 +3,211 @@ package eino
import (
"context"
"fmt"
"rag/consts/model"
"rag/model/entity"
"gitea.com/red-future/common/jaeger"
"gitea.com/red-future/common/utils"
"github.com/cloudwego/eino-ext/components/embedding/ark"
"github.com/cloudwego/eino-ext/components/embedding/dashscope"
"github.com/cloudwego/eino-ext/components/embedding/ollama"
"github.com/cloudwego/eino-ext/components/embedding/openai"
"github.com/cloudwego/eino-ext/components/embedding/qianfan"
"github.com/cloudwego/eino-ext/components/embedding/tencentcloud"
"github.com/cloudwego/eino/components/embedding"
"github.com/gogf/gf/v2/frame/g"
"github.com/golang/glog"
"github.com/gogf/gf/v2/util/gconv"
)
// 全局只初始化一次
var (
EmbedderArk *ark.Embedder
EmbedderDashscope *dashscope.Embedder
EmbedderOpenAI *openai.Embedder
)
// EmbedderSet holds one tenant's embedding clients, one optional slot per
// supported provider. A nil field means that provider has not been initialised
// for the tenant.
type EmbedderSet struct {
Ark *ark.Embedder
Ollama *ollama.Embedder
OpenAI *openai.Embedder
Qianfan *qianfan.Embedder
TencentCloud *tencentcloud.Embedder
DashScope *dashscope.Embedder
}
// tenantEmbedders is the package-wide tenant registry:
// key = tenant ID, value = that tenant's embedding models.
// NOTE(review): the map is written by InitVector/RefreshTenantEmbedder and read
// by GetTenantEmbedder without a lock — confirm these never run concurrently.
var tenantEmbedders = make(map[uint64]*EmbedderSet)
func init() {
ctx := context.Background()
if !g.Cfg().MustGet(ctx, "eino.embedding").IsEmpty() {
var err error
provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String()
switch provider {
case providerArk:
cfg := &ark.EmbeddingConfig{
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
}
if apiType := g.Cfg().MustGet(ctx, "eino.embedding.apiType").String(); apiType != "" {
apiTypeVal := ark.APIType(apiType)
cfg.APIType = &apiTypeVal
}
EmbedderArk, err = ark.NewEmbedder(ctx, cfg)
case providerOpenai:
chatModelConfig := &openai.EmbeddingConfig{
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
}
EmbedderOpenAI, err = openai.NewEmbedder(ctx, chatModelConfig)
case providerDashscope:
cfg := &dashscope.EmbeddingConfig{
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
}
EmbedderDashscope, err = dashscope.NewEmbedder(ctx, cfg)
}
if err != nil {
glog.Fatalf("NewEmbedder of %v error: %v", provider, err)
}
}
ctx, span := jaeger.NewSpan(ctx, "InitAllVector")
defer span.End()
InitAllVector(ctx)
return
}
func EmbedStrings(ctx context.Context, texts []string) (embeddings [][]float64, err error) {
provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String()
switch provider {
case providerArk:
return EmbedderArk.EmbedStrings(ctx, texts)
case providerOpenai:
return EmbedderOpenAI.EmbedStrings(ctx, texts)
case providerDashscope:
return EmbedderDashscope.EmbedStrings(ctx, texts)
// ===================== 1. 服务启动时调用:初始化所有租户 =====================
func InitAllVector(ctx context.Context) {
//list, err := dao.Model.GetNoTenantId(ctx, &dto.GetModelReq{
// ModelType: model.ModelTypeVector.Code(),
//})
//if err != nil {
// g.Log().Errorf(ctx, "获取所有租户ID失败: %v", err)
// return
//}
//
//for _, l := range list {
// err = InitVector(ctx, l)
// if err != nil {
// g.Log().Errorf(ctx, "初始化租户[%v]的向量模型失败: %v", l.TenantId, err)
// continue
// }
//}
modelDO := new(entity.Model)
modelDO.TenantId = 1
modelDO.ConfigType = model.ModelConfigTypeVectorDashScope.Code()
var cfg entity.VectorModelConfigDashScope
cfg.APIKey = "sk-4a8b82770bf74bc490eb3e4c5a8e2be9"
cfg.Model = "text-embedding-v3"
modelDO.ConfigContent = gconv.Map(&cfg)
err := InitVector(ctx, modelDO)
if err != nil {
g.Log().Errorf(ctx, "初始化向量模型失败: %v", err)
return
}
return nil, fmt.Errorf("unsupported provider: %v", provider)
}
// InitVector builds the embedding client described by modelDO and registers it
// for modelDO.TenantId.
//
// The provider is selected by modelDO.ConfigType; modelDO.ConfigContent is
// decoded into the matching provider config struct.
//
// The new client is merged into a copy of the tenant's existing EmbedderSet,
// so initialising a second provider for the same tenant no longer discards
// the first one. The tenant's entry is replaced only after the client was
// created successfully.
//
// It returns an error when modelDO or its ConfigType is nil, the config cannot
// be decoded, the provider type is unsupported, or the client cannot be created.
func InitVector(ctx context.Context, modelDO *entity.Model) (err error) {
	if modelDO == nil || modelDO.ConfigType == nil {
		return fmt.Errorf("向量模型配置不能为空")
	}

	// Copy-on-write: start from the tenant's current set so other providers survive.
	set := &EmbedderSet{}
	if prev := tenantEmbedders[modelDO.TenantId]; prev != nil {
		*set = *prev
	}

	switch *modelDO.ConfigType {
	case *model.ModelConfigTypeVectorArk.Code():
		// Ark embedding config.
		var cfg entity.VectorModelConfigArk
		if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
			return fmt.Errorf("解析Ark向量配置失败: %w", err)
		}
		arkCfg := &ark.EmbeddingConfig{
			APIKey: cfg.APIKey,
			Model:  cfg.Model,
		}
		// APIType is optional; only set it when configured.
		if !g.IsEmpty(cfg.APIType) {
			arkCfg.APIType = new(ark.APIType(cfg.APIType))
		}
		set.Ark, err = ark.NewEmbedder(ctx, arkCfg)

	case *model.ModelConfigTypeVectorOllama.Code():
		// Ollama embedding config.
		var cfg entity.VectorModelConfigOllama
		if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
			return fmt.Errorf("解析Ollama向量配置失败: %w", err)
		}
		set.Ollama, err = ollama.NewEmbedder(ctx, &ollama.EmbeddingConfig{
			BaseURL: cfg.BaseURL,
			Model:   cfg.Model,
		})

	case *model.ModelConfigTypeVectorOpenAI.Code():
		// OpenAI embedding config.
		var cfg entity.VectorModelConfigOpenAI
		if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
			return fmt.Errorf("解析OpenAI向量配置失败: %w", err)
		}
		openaiCfg := &openai.EmbeddingConfig{
			APIKey:     cfg.APIKey,
			Model:      cfg.Model,
			ByAzure:    cfg.ByAzure,
			BaseURL:    cfg.BaseURL,
			APIVersion: cfg.APIVersion,
		}
		set.OpenAI, err = openai.NewEmbedder(ctx, openaiCfg)

	case *model.ModelConfigTypeVectorQianfan.Code():
		// Qianfan embedding config.
		var cfg entity.VectorModelConfigQianfan
		if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
			return fmt.Errorf("解析千帆向量配置失败: %w", err)
		}
		// NOTE(review): credentials go into the config returned by
		// GetQianfanSingletonConfig; the name suggests it is shared process-wide,
		// so different tenants' keys may overwrite each other — confirm.
		qcfg := qianfan.GetQianfanSingletonConfig()
		qcfg.AccessKey = cfg.AccessKey
		qcfg.SecretKey = cfg.SecretKey
		set.Qianfan, err = qianfan.NewEmbedder(ctx, &qianfan.EmbeddingConfig{
			Model: cfg.Model,
		})

	case *model.ModelConfigTypeVectorTencentCloud.Code():
		// Tencent Cloud embedding config.
		var cfg entity.VectorModelConfigTencentCloud
		if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
			return fmt.Errorf("解析腾讯云向量配置失败: %w", err)
		}
		set.TencentCloud, err = tencentcloud.NewEmbedder(ctx, &tencentcloud.EmbeddingConfig{
			SecretID:  cfg.SecretID,
			SecretKey: cfg.SecretKey,
			Region:    cfg.Region,
		})

	case *model.ModelConfigTypeVectorDashScope.Code():
		// Alibaba DashScope embedding config.
		var cfg entity.VectorModelConfigDashScope
		if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
			return fmt.Errorf("解析阿里dashscope向量配置失败: %w", err)
		}
		set.DashScope, err = dashscope.NewEmbedder(ctx, &dashscope.EmbeddingConfig{
			APIKey: cfg.APIKey,
			Model:  cfg.Model,
		})

	default:
		return fmt.Errorf("不支持的向量模型配置类型: %v", *modelDO.ConfigType)
	}

	// Shared error handling for all providers.
	if err != nil {
		return fmt.Errorf("初始化向量模型失败: %w", err)
	}

	// Stored without locking, as before.
	tenantEmbedders[modelDO.TenantId] = set
	// Log the config type value, not the pointer.
	g.Log().Infof(ctx, "向量模型[%v]初始化成功", *modelDO.ConfigType)
	return nil
}
// GetTenantEmbedder returns the embedder set registered for tenantId, or an
// error when the tenant has no (non-nil) entry in tenantEmbedders.
func GetTenantEmbedder(tenantId uint64) (*EmbedderSet, error) {
	if set := tenantEmbedders[tenantId]; set != nil {
		return set, nil
	}
	return nil, fmt.Errorf("租户[%v]的向量模型未初始化", tenantId)
}
// GetTenantEmbedderByType resolves the embedder of the given provider type for
// the tenant of the user carried in ctx.
//
// It returns an error when configType is nil, the user info cannot be read
// from ctx, the tenant has no registered embedders, the provider type is
// unsupported, or the tenant has not initialised that provider. The last check
// matters: returning a nil concrete pointer through the interface would give
// the caller a non-nil interface value that panics on first use.
func GetTenantEmbedderByType(ctx context.Context, configType model.ModelConfigType) (embedding.Embedder, error) {
	if configType == nil {
		return nil, fmt.Errorf("向量模型配置类型不能为空")
	}

	userInfo, err := utils.GetUserInfo(ctx)
	if err != nil {
		return nil, err
	}
	set, err := GetTenantEmbedder(userInfo.TenantId)
	if err != nil {
		return nil, err
	}

	var (
		embedder embedding.Embedder
		ready    bool // true only when the concrete pointer behind embedder is non-nil
	)
	switch *configType {
	case *model.ModelConfigTypeVectorArk.Code():
		embedder, ready = set.Ark, set.Ark != nil
	case *model.ModelConfigTypeVectorOllama.Code():
		embedder, ready = set.Ollama, set.Ollama != nil
	case *model.ModelConfigTypeVectorOpenAI.Code():
		embedder, ready = set.OpenAI, set.OpenAI != nil
	case *model.ModelConfigTypeVectorQianfan.Code():
		embedder, ready = set.Qianfan, set.Qianfan != nil
	case *model.ModelConfigTypeVectorTencentCloud.Code():
		embedder, ready = set.TencentCloud, set.TencentCloud != nil
	case *model.ModelConfigTypeVectorDashScope.Code():
		embedder, ready = set.DashScope, set.DashScope != nil
	default:
		return nil, fmt.Errorf("不支持的向量模型配置类型: %v", *configType)
	}

	if !ready {
		return nil, fmt.Errorf("租户[%v]的向量模型[%v]未初始化", userInfo.TenantId, *configType)
	}
	return embedder, nil
}
// RefreshTenantEmbedder drops the tenant's registered embedders and then
// re-initialises the one described by modelDO, returning InitVector's error.
// If InitVector fails the tenant is left without an entry.
func RefreshTenantEmbedder(ctx context.Context, modelDO *entity.Model) error {
	tenantID := modelDO.TenantId
	delete(tenantEmbedders, tenantID)

	return InitVector(ctx, modelDO)
}

View File

@@ -1,273 +0,0 @@
/*
* Copyright 2024 Red Future Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package eino
import (
"context"
"fmt"
"net/http"
"time"
"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components"
"github.com/cloudwego/eino/components/embedding"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/net/gclient"
"github.com/gogf/gf/v2/util/gconv"
)
var (
// Default settings for the Qwen (DashScope) embedding API.
// defaultBaseURL is the text-embedding endpoint used when no BaseURL is configured.
defaultBaseURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding"
// defaultTimeout is the request timeout used when no Timeout is configured.
defaultTimeout = 10 * time.Minute
// defaultRetryTimes is the default number of retry attempts.
defaultRetryTimes = 2
)
// QwenEmbeddingConfig holds the settings for a QwenEmbedder.
// Empty optional fields are filled with defaults by buildQwenClient.
type QwenEmbeddingConfig struct {
// Timeout specifies the maximum duration to wait for API responses
// Optional. Default: 10 minutes
Timeout *time.Duration `json:"timeout"`
// HTTPClient specifies the client to send HTTP requests.
// Optional.
// NOTE(review): buildQwenClient always builds its client with g.Client() and
// does not read this field — confirm whether it is meant to be honoured.
HTTPClient *http.Client `json:"http_client"`
// RetryTimes specifies the number of retry attempts for failed API calls
// Optional. Default: 2
// NOTE(review): only defaulted here; no retry loop is visible in this file.
RetryTimes *int `json:"retry_times"`
// BaseURL specifies the base URL for Qwen DashScope service
// Optional. Default: "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding"
BaseURL string `json:"base_url"`
// APIKey specifies the API Key for authentication
// Required
APIKey string `json:"api_key"`
// Model specifies the model name for Qwen embedding
// Required. Examples: "text-embedding-v2", "text-embedding-v3"
Model string `json:"model"`
// TextType specifies the type of text: "document" or "query"
// Optional. Default: "document"
TextType string `json:"text_type"`
// MaxConcurrentRequests specifies the maximum number of concurrent requests allowed
// Optional. Default: 5
// NOTE(review): only defaulted here; no concurrency limiting is visible in this file.
MaxConcurrentRequests *int `json:"max_concurrent_requests"`
}
// QwenEmbedder calls the Qwen (DashScope) text-embedding HTTP API.
type QwenEmbedder struct {
// client is the base HTTP client; callEmbeddingAPI clones it for each request.
client *gclient.Client
// conf is the configuration passed to NewQwenEmbedder, with defaults filled in.
conf *QwenEmbeddingConfig
}
// EmbeddingRequest is the request body sent to the Qwen embedding API.
type EmbeddingRequest struct {
// Model is the embedding model name.
Model string `json:"model"`
// Input carries the texts to embed.
Input struct {
Texts []string `json:"texts"`
} `json:"input"`
// Parameters carries the text type taken from the embedder configuration.
Parameters struct {
TextType string `json:"text_type,omitempty"`
} `json:"parameters,omitempty"`
}
// EmbeddingResponse is the response body returned by the Qwen embedding API.
type EmbeddingResponse struct {
Output struct {
// Embeddings holds one entry per embedded text; TextIndex is the position
// of that text in the request's Input.Texts.
Embeddings []struct {
TextIndex int `json:"text_index"`
Embedding []float64 `json:"embedding"`
} `json:"embeddings"`
} `json:"output"`
// Usage reports the token count for the request.
Usage struct {
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
// RequestID identifies the request on the service side.
RequestID string `json:"request_id"`
}
// APIError is the error payload decoded from a non-200 Qwen API response.
type APIError struct {
	Code      string `json:"code"`
	Message   string `json:"message"`
	RequestID string `json:"request_id"`
}

// Error implements the error interface, rendering the code, the message and
// the request ID on a single line.
func (e *APIError) Error() string {
	code, message, requestID := e.Code, e.Message, e.RequestID
	rendered := fmt.Sprintf("API Error: %s - %s (RequestID: %s)", code, message, requestID)
	return rendered
}
// buildQwenClient fills the unset optional fields of config with their
// defaults (mutating config in place) and returns a GoFrame HTTP client whose
// timeout is config.Timeout.
//
// Defaults are copied into fresh locals before their address is stored, so a
// caller that later writes through config.Timeout or config.RetryTimes cannot
// change the package-level defaults shared by every embedder.
func buildQwenClient(config *QwenEmbeddingConfig) *gclient.Client {
	if len(config.BaseURL) == 0 {
		config.BaseURL = defaultBaseURL
	}
	if config.Timeout == nil {
		timeout := defaultTimeout
		config.Timeout = &timeout
	}
	if config.RetryTimes == nil {
		retryTimes := defaultRetryTimes
		config.RetryTimes = &retryTimes
	}
	if len(config.TextType) == 0 {
		config.TextType = "document"
	}
	if config.MaxConcurrentRequests == nil {
		maxConcurrentRequests := 5
		config.MaxConcurrentRequests = &maxConcurrentRequests
	}

	client := g.Client()
	client.SetTimeout(*config.Timeout)
	return client
}
// NewQwenEmbedder validates config and builds a QwenEmbedder from it.
//
// config must be non-nil and carry both APIKey and Model; otherwise an error
// is returned. Optional fields left empty are defaulted in place by
// buildQwenClient, so config is modified.
func NewQwenEmbedder(ctx context.Context, config *QwenEmbeddingConfig) (*QwenEmbedder, error) {
	// Guard first: every check below dereferences config.
	if config == nil {
		return nil, fmt.Errorf("[Qwen] config is required")
	}
	if len(config.APIKey) == 0 {
		return nil, fmt.Errorf("[Qwen] APIKey is required")
	}
	if len(config.Model) == 0 {
		return nil, fmt.Errorf("[Qwen] Model is required")
	}

	client := buildQwenClient(config)
	return &QwenEmbedder{
		client: client,
		conf:   config,
	}, nil
}
// EmbedStrings embeds texts through the Qwen API and returns one vector per
// input text, reporting start, end and error events to the eino callbacks.
//
// It returns an error when texts is empty, when the API call fails, or when a
// panic is raised while the call is in flight. The results are named so that
// the deferred recover can turn a panic into a returned error; with unnamed
// results the caller would receive (nil, nil) after a recovered panic.
func (e *QwenEmbedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) (
	embeddings [][]float64, err error) {
	if len(texts) == 0 {
		return nil, fmt.Errorf("[Qwen] texts cannot be empty")
	}

	options := embedding.GetCommonOptions(&embedding.Options{
		Model: &e.conf.Model,
	}, opts...)
	conf := &embedding.Config{
		Model: dereferenceOrZero(options.Model),
	}

	ctx = callbacks.EnsureRunInfo(ctx, e.GetType(), components.ComponentOfEmbedding)
	ctx = callbacks.OnStart(ctx, &embedding.CallbackInput{
		Texts:  texts,
		Config: conf,
	})

	defer func() {
		if r := recover(); r != nil {
			// Surface the panic to the caller instead of returning (nil, nil).
			embeddings = nil
			err = fmt.Errorf("[Qwen] panic: %v", r)
			callbacks.OnError(ctx, err)
		}
	}()

	// Call the Qwen API to obtain the embeddings.
	var usage *embedding.TokenUsage
	embeddings, usage, err = e.callEmbeddingAPI(ctx, texts)
	if err != nil {
		callbacks.OnError(ctx, err)
		return nil, err
	}

	callbacks.OnEnd(ctx, &embedding.CallbackOutput{
		Embeddings: embeddings,
		Config:     conf,
		TokenUsage: usage,
	})
	return embeddings, nil
}
// callEmbeddingAPI posts texts to the configured Qwen endpoint and returns the
// vectors ordered like texts, together with the reported token usage.
//
// A non-200 response is returned as *APIError when the body decodes into one
// with a non-empty code, otherwise as a generic status error. An error is also
// returned when the response lacks a vector for any input text, so callers
// never receive nil entries in the result.
func (e *QwenEmbedder) callEmbeddingAPI(ctx context.Context, texts []string) ([][]float64, *embedding.TokenUsage, error) {
	// Build the request body.
	var req EmbeddingRequest
	req.Model = e.conf.Model
	req.Input.Texts = texts
	req.Parameters.TextType = e.conf.TextType

	// Send the request on a clone so the per-request headers do not leak into
	// the shared client.
	client := e.client.Clone()
	client.SetHeader("Authorization", "Bearer "+e.conf.APIKey)
	client.SetHeader("Content-Type", "application/json")
	client.SetTimeout(*e.conf.Timeout)
	resp, err := client.Post(ctx, e.conf.BaseURL, req)
	if err != nil {
		return nil, nil, fmt.Errorf("[Qwen] HTTP request error: %w", err)
	}
	defer resp.Close()

	// Check the status code.
	if resp.StatusCode != http.StatusOK {
		var errResp APIError
		result := resp.ReadAll()
		if err = gconv.Struct(result, &errResp); err == nil && errResp.Code != "" {
			return nil, nil, &errResp
		}
		return nil, nil, fmt.Errorf("[Qwen] HTTP status error: %d", resp.StatusCode)
	}

	// Decode the response body.
	var apiResp EmbeddingResponse
	result := resp.ReadAll()
	if err = gconv.Struct(result, &apiResp); err != nil {
		return nil, nil, fmt.Errorf("[Qwen] parse response error: %w", err)
	}

	// Place each vector at the position of its source text.
	embeddings := make([][]float64, len(texts))
	for _, emb := range apiResp.Output.Embeddings {
		if emb.TextIndex >= 0 && emb.TextIndex < len(embeddings) {
			embeddings[emb.TextIndex] = emb.Embedding
		}
	}
	// Reject partial responses rather than handing back nil vectors.
	for i, vec := range embeddings {
		if len(vec) == 0 {
			return nil, nil, fmt.Errorf("[Qwen] missing embedding for text index %d (request_id=%s)", i, apiResp.RequestID)
		}
	}

	usage := &embedding.TokenUsage{
		TotalTokens: apiResp.Usage.TotalTokens,
	}
	g.Log().Debugf(ctx, "[Qwen] Embedding success: request_id=%s, total_tokens=%d", apiResp.RequestID, usage.TotalTokens)
	return embeddings, usage, nil
}
// GetType returns the component type name; EmbedStrings passes it to
// callbacks.EnsureRunInfo.
func (e *QwenEmbedder) GetType() string {
return getType()
}
// IsCallbacksEnabled always reports true.
func (e *QwenEmbedder) IsCallbacksEnabled() bool {
return true
}
// getType returns the fixed component name "Qwen".
func getType() string {
return "Qwen"
}
// dereferenceOrZero returns the value v points to, or the zero value of T
// when v is nil.
func dereferenceOrZero[T any](v *T) T {
	var zero T
	if v != nil {
		return *v
	}
	return zero
}

View File

@@ -5,6 +5,7 @@ import (
"database/sql"
"errors"
"fmt"
"rag/consts/model"
"rag/dao"
"rag/model/dto"
"rag/model/entity"
@@ -34,7 +35,14 @@ func NewPGVectorIndexer(opts *PGVectorIndexerOptions) *PGVectorIndexer {
return &PGVectorIndexer{opts: opts}
}
func (i *PGVectorIndexer) Store(ctx context.Context, docs []*schema.Document, opts ...indexer.Option) (rows int64, err error) {
func (i *PGVectorIndexer) Store(ctx context.Context, docs []*schema.Document, configType model.ModelConfigType, opts ...indexer.Option) (rows int64, err error) {
embedderByType, err := GetTenantEmbedderByType(ctx, configType)
if err != nil {
return
}
indexer.WithEmbedding(embedderByType)
commonOpts := indexer.GetCommonOptions(&indexer.Options{}, opts...)
if commonOpts.Embedding == nil {
@@ -100,9 +108,9 @@ func (i *PGVectorIndexer) doStore(ctx context.Context, docs []*schema.Document,
}
// 转成业务实体
var chunks []*dto.VectorDocumentChunkMsg
var chunks []*dto.VectorDocumentVectorMsg
for idx, doc := range docs {
ck := new(dto.VectorDocumentChunkMsg)
ck := new(dto.VectorDocumentVectorMsg)
err = gconv.Struct(doc.MetaData, ck)
if err != nil {
glog.Errorf(ctx, "doStore err: %v", err)
@@ -126,7 +134,7 @@ func (i *PGVectorIndexer) doStore(ctx context.Context, docs []*schema.Document,
return
}
// 入库
rows, err = dao.DocumentChunk.BatchInsert(ctx, chunks)
rows, err = dao.DocumentVector.BatchInsert(ctx, chunks)
return
}

View File

@@ -1,11 +0,0 @@
package eino
// TaskPriority is the priority of a task, expressed as a string code.
type TaskPriority string
const (
TaskPriorityLow TaskPriority = "low" // low priority
TaskPriorityMedium TaskPriority = "medium" // medium priority
TaskPriorityHigh TaskPriority = "high" // high priority
TaskPriorityUrgent TaskPriority = "urgent" // urgent
)

116
common/eino/rerank.go Normal file
View File

@@ -0,0 +1,116 @@
package eino
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"github.com/cloudwego/eino/schema"
"github.com/gogf/gf/v2/frame/g"
)
// DashScopeReranker re-ranks documents through the DashScope (Tongyi Bailian)
// rerank service, a cross-encoder style fine-ranking step.
type DashScopeReranker struct {
	httpClient *http.Client
}

// NewDashScopeReranker returns a reranker whose HTTP client times out after
// ten seconds.
func NewDashScopeReranker() *DashScopeReranker {
	client := new(http.Client)
	client.Timeout = 10 * time.Second

	reranker := new(DashScopeReranker)
	reranker.httpClient = client
	return reranker
}
// Rerank scores docs against query with the DashScope text-rerank API and
// returns them ordered from the highest relevance score to the lowest.
//
// The API key and model are read from the "eino.rerank.apiKey" and
// "eino.rerank.model" configuration keys. Each returned document has its score
// stored under MetaData["rerank_score"]; documents the service did not score
// keep a score of 0. An empty docs slice is returned unchanged.
//
// An error is returned when the request cannot be built or sent, when the
// service answers with a non-2xx status, or when the response cannot be decoded.
func (d *DashScopeReranker) Rerank(ctx context.Context, query string, docs []*schema.Document) ([]*schema.Document, error) {
	if len(docs) == 0 {
		return docs, nil
	}

	// Official DashScope rerank endpoint.
	url := "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank"
	apiKey := g.Cfg().MustGet(ctx, "eino.rerank.apiKey").String()
	model := g.Cfg().MustGet(ctx, "eino.rerank.model").String()

	documents := make([]string, len(docs))
	for i, doc := range docs {
		documents[i] = doc.Content
	}

	reqBody := map[string]any{
		"model": model,
		"input": map[string]any{
			"query":     query,
			"documents": documents,
		},
		"parameters": map[string]any{
			"top_n": len(docs),
		},
	}

	bs, err := json.Marshal(reqBody)
	if err != nil {
		return nil, fmt.Errorf("rerank: encode request: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bs))
	if err != nil {
		return nil, fmt.Errorf("rerank: build request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+apiKey)
	req.Header.Set("Content-Type", "application/json")

	resp, err := d.httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		// Best effort: the body is only used to enrich the error message.
		body, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("rerank api error: status=%d, body=%s", resp.StatusCode, string(body))
	}

	// Decode the result.
	var result struct {
		Output struct {
			Results []struct {
				Index          int     `json:"index"`
				RelevanceScore float64 `json:"relevance_score"`
			} `json:"results"`
		} `json:"output"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, err
	}

	type scoredDoc struct {
		doc   *schema.Document
		score float64
	}

	scored := make([]scoredDoc, len(docs))
	for i, doc := range docs {
		scored[i] = scoredDoc{doc: doc, score: 0}
	}
	for _, res := range result.Output.Results {
		// Ignore indexes outside the submitted documents instead of panicking
		// on a malformed response.
		if res.Index < 0 || res.Index >= len(scored) {
			continue
		}
		scored[res.Index].score = res.RelevanceScore
	}

	// Order by score, highest first.
	for i := 0; i < len(scored); i++ {
		for j := i + 1; j < len(scored); j++ {
			if scored[j].score > scored[i].score {
				scored[i], scored[j] = scored[j], scored[i]
			}
		}
	}

	// Emit the documents in their final order, annotated with the score.
	ranked := make([]*schema.Document, 0, len(scored))
	for _, s := range scored {
		// Writing to a nil map panics; allocate one when the document has none.
		if s.doc.MetaData == nil {
			s.doc.MetaData = make(map[string]any, 1)
		}
		s.doc.MetaData["rerank_score"] = s.score
		ranked = append(ranked, s.doc)
	}

	return ranked, nil
}

View File

@@ -3,11 +3,18 @@ package eino
import (
"context"
"errors"
"fmt"
"rag/consts/model"
"rag/dao"
"sort"
"time"
"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components/embedding"
"github.com/cloudwego/eino/components/retriever"
"github.com/cloudwego/eino/schema"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/grpool"
"github.com/gogf/gf/v2/util/gconv"
"github.com/pgvector/pgvector-go"
)
@@ -16,63 +23,173 @@ type PGVectorRetrieverConfig struct {
Embedder embedding.Embedder
DefaultTopK int
DefaultIndex string
DSLInfo map[string]any
}
type PGVectorRetriever struct {
embedder embedding.Embedder
topK int
index string
dslInfo map[string]any
reranker *DashScopeReranker // 通义精排
}
func NewPGVectorRetriever(config *PGVectorRetrieverConfig) (*PGVectorRetriever, error) {
if config.Embedder == nil {
return nil, errors.New("embedder is required")
}
func NewPGVectorRetriever(ctx context.Context, config *PGVectorRetrieverConfig, configType model.ModelConfigType) (*PGVectorRetriever, error) {
if config.DefaultTopK <= 0 {
config.DefaultTopK = 5
}
e, err := GetTenantEmbedderByType(ctx, configType)
if err != nil {
return nil, err
}
return &PGVectorRetriever{
embedder: config.Embedder,
embedder: e,
topK: config.DefaultTopK,
index: config.DefaultIndex,
dslInfo: config.DSLInfo,
//reranker: NewDashScopeReranker(), // 👈 直接初始化你的精排
}, nil
}
func (r *PGVectorRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
// 1. 处理公共 Option官方标准写法
options := &retriever.Options{
Index: &r.index,
TopK: &r.topK,
DSLInfo: r.dslInfo,
Embedding: r.embedder,
}
options = retriever.GetCommonOptions(options, opts...)
// 2. 回调(官方标准)
// 安全保护:防止 nil 指针 panic
topK := 10
if options.TopK != nil {
topK = *options.TopK
}
ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{
Query: query,
TopK: *options.TopK,
})
// 3. 执行检索
docs, err := r.doRetrieve(ctx, query, options)
if err != nil {
callbacks.OnError(ctx, err)
// ==========================================
// 🔥 优化版grpool 并行双路检索(安全、健壮、无泄漏)
// ==========================================
var (
docsVector []*schema.Document
docsFulltext []*schema.Document
errVector error
errFulltext error
// 缓冲通道=2确保无死锁等待
done = make(chan struct{}, 2)
)
// 上下文:超时 + 可取消双保障建议5s超时根据业务调整
taskCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
// 封装并行任务函数,消除重复代码
runTask := func(task func() error, errTarget *error) {
defer func() {
// 任务结束必发信号,确保通道不阻塞
done <- struct{}{}
}()
// 捕获 panic + 执行业务逻辑
g.TryCatch(taskCtx, func(ctx context.Context) {
*errTarget = task()
}, func(ctx context.Context, panicErr error) {
*errTarget = panicErr
})
// 任务失败:立即取消另一个任务(快速失败)
if *errTarget != nil {
cancel()
}
}
// ----------------------
// 并行提交两个检索任务
// ----------------------
// 任务1向量检索
grpool.Add(taskCtx, func(ctx context.Context) {
runTask(func() error {
docsVector, errVector = r.doRetrieveVector(ctx, query, options)
return errVector
}, &errVector)
})
// 任务2全文检索
grpool.Add(taskCtx, func(ctx context.Context) {
runTask(func() error {
docsFulltext, errFulltext = r.doRetrieveMeilisearch(ctx, query, options)
return errFulltext
}, &errFulltext)
})
// ----------------------
// 安全等待所有任务完成
// ----------------------
<-done
<-done
// ----------------------
// 统一错误处理
// ----------------------
// 用 errors.Join 合并所有错误,不丢失信息
if err := errors.Join(errVector, errFulltext); err != nil {
return nil, err
}
// 4. 完成回调
callbacks.OnEnd(ctx, &retriever.CallbackOutput{
Docs: docs,
})
// 合并 + 智能去重(保留最优分数)
mergedDocs := mergeAndDeduplicate(docsVector, docsFulltext)
return docs, nil
// =========================
// 🔥 Cross-Encoder 精排
// =========================
var finalDocs []*schema.Document
if r.reranker != nil {
ranked, err := r.reranker.Rerank(ctx, query, mergedDocs)
if err != nil {
return nil, fmt.Errorf("rerank failed: %w", err)
}
finalDocs = ranked
} else {
sort.Slice(mergedDocs, func(i, j int) bool {
d1 := gconv.Float64(mergedDocs[i].MetaData["distance"])
d2 := gconv.Float64(mergedDocs[j].MetaData["distance"])
return d1 < d2
})
finalDocs = mergedDocs
}
// =========================
// 过滤无效文档
// =========================
const maxDistance = 0.8
validDocs := make([]*schema.Document, 0, len(finalDocs))
for _, doc := range finalDocs {
dist := gconv.Float64(doc.MetaData["distance"])
if dist <= maxDistance {
validDocs = append(validDocs, doc)
}
}
// 最多保留 topK
if len(validDocs) > topK {
validDocs = validDocs[:topK]
}
callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: validDocs})
return validDocs, nil
}
func (r *PGVectorRetriever) doRetrieve(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
// 1. 生成向量
// ==========================================
// 1. 向量检索PG
// ==========================================
func (r *PGVectorRetriever) doRetrieveVector(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
vectors, err := opts.Embedding.EmbedStrings(ctx, []string{query})
if err != nil {
return nil, err
@@ -81,37 +198,103 @@ func (r *PGVectorRetriever) doRetrieve(ctx context.Context, query string, opts *
return nil, errors.New("empty query vector")
}
queryVec := pgvector.NewVector(vectors[0])
topK := *opts.TopK
queryVec := pgvector.NewVector(gconv.Float32s(vectors[0]))
topK := 10
if opts.TopK != nil {
topK = *opts.TopK
}
var datasetIds, documentIds []int64
if g.IsEmpty(opts.DSLInfo["dataset_ids"]) {
datasetIds = gconv.Int64s(opts.DSLInfo["dataset_ids"])
}
if g.IsEmpty(opts.DSLInfo["document_ids"]) {
documentIds = gconv.Int64s(opts.DSLInfo["document_ids"])
}
// 2. PG 向量相似度检索 SQL
sql := `
SELECT id, content, dataset_id, document_id,
vector <-> ? AS distance
FROM document_chunk
ORDER BY distance ASC
LIMIT ?
`
// 3. 查询
rows, err := dao.DocumentChunk.GetDB().GetAll(ctx, sql, queryVec, topK)
rows, err := dao.DocumentVector.GetAllByVector(ctx, datasetIds, documentIds, queryVec, topK)
if err != nil {
return nil, err
}
// 4. 转为 Eino Document
docs := make([]*schema.Document, 0, len(rows))
for _, row := range rows {
docs = append(docs, &schema.Document{
ID: gconv.String(row["id"]),
Content: gconv.String(row["content"]),
Metadata: map[string]any{
"dataset_id": row["dataset_id"],
"document_id": row["document_id"],
"distance": row["distance"],
MetaData: map[string]any{
"dataset_id": gconv.Int64(row["dataset_id"]),
"document_id": gconv.Int64(row["document_id"]),
"distance": gconv.Float64(row["distance"]),
"retrieve_by": "vector",
},
})
}
return docs, nil
}
// ==========================================
// 2. 全文检索Meilisearch🔥 新增
// ==========================================
// doRetrieveMeilisearch runs the full-text leg of the retrieval through
// dao.DocumentVector.SearchByKeywords and converts the rows to documents
// tagged with retrieve_by = "fulltext".
//
// opts.DSLInfo["dataset_ids"] and opts.DSLInfo["document_ids"], when present
// and non-empty, are passed to the DAO as filters. A nil opts.TopK falls back
// to 10, the same default doRetrieveVector uses.
//
// The row's "_rankingScore" is converted with distance = 1 - score, so that a
// smaller distance means a better match, as it does for the vector leg, whose
// results Retrieve sorts ascending and filters by a maximum distance.
// NOTE(review): this assumes "_rankingScore" lies in [0, 1] with higher
// meaning more relevant — confirm against the DAO's search settings.
func (r *PGVectorRetriever) doRetrieveMeilisearch(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
	topK := 10
	if opts.TopK != nil {
		topK = *opts.TopK
	}

	// Only read a filter when it is actually provided; the previous inverted
	// check read it exclusively when empty, so filters were never applied.
	var datasetIds, documentIds []int64
	if !g.IsEmpty(opts.DSLInfo["dataset_ids"]) {
		datasetIds = gconv.Int64s(opts.DSLInfo["dataset_ids"])
	}
	if !g.IsEmpty(opts.DSLInfo["document_ids"]) {
		documentIds = gconv.Int64s(opts.DSLInfo["document_ids"])
	}

	rows, err := dao.DocumentVector.SearchByKeywords(ctx, query, datasetIds, documentIds, topK)
	if err != nil {
		return nil, err
	}

	docs := make([]*schema.Document, 0, len(rows))
	for _, row := range rows {
		score := gconv.Float64(row["_rankingScore"])
		distance := 1 - score

		docs = append(docs, &schema.Document{
			ID:      gconv.String(row["id"]),
			Content: gconv.String(row["content"]),
			MetaData: map[string]any{
				"dataset_id":  gconv.Int64(row["dataset_id"]),
				"document_id": gconv.Int64(row["document_id"]),
				"distance":    distance,
				"retrieve_by": "fulltext",
			},
		})
	}
	return docs, nil
}
// ==========================================
// 合并去重(智能版:两路都命中时,保留向量结果 + 全文标记)
// ==========================================
// mergeAndDeduplicate merges the vector and full-text result lists, keeping
// one document per ID.
//
// Vector results come first in their original order, followed by the
// full-text results whose ID was not already present. When a full-text hit
// shares its ID with a vector hit, the vector document is kept (its distance
// untouched) and its MetaData["retrieve_by"] is set to "both". Nil documents
// are skipped. The output order is deterministic; it no longer depends on map
// iteration order.
func mergeAndDeduplicate(vecDocs, fullDocs []*schema.Document) []*schema.Document {
	merged := make([]*schema.Document, 0, len(vecDocs)+len(fullDocs))
	position := make(map[string]int, len(vecDocs)+len(fullDocs))

	// Vector results first; a repeated ID keeps the later document.
	for _, d := range vecDocs {
		if d == nil {
			continue
		}
		if at, seen := position[d.ID]; seen {
			merged[at] = d
			continue
		}
		position[d.ID] = len(merged)
		merged = append(merged, d)
	}
	vectorCount := len(merged)

	// Then full-text results: append new IDs, flag IDs hit by both legs.
	for _, d := range fullDocs {
		if d == nil {
			continue
		}
		at, seen := position[d.ID]
		if !seen {
			position[d.ID] = len(merged)
			merged = append(merged, d)
			continue
		}
		if at >= vectorCount {
			// Duplicate inside the full-text list only: not a dual hit.
			continue
		}
		hit := merged[at]
		if hit.MetaData == nil {
			hit.MetaData = make(map[string]any, 1)
		}
		hit.MetaData["retrieve_by"] = "both"
	}
	return merged
}

View File

@@ -1,12 +0,0 @@
package eino
// TaskStatus is the lifecycle state of a task, expressed as a string code.
type TaskStatus string
const (
TaskStatusPending TaskStatus = "pending" // waiting to be processed
TaskStatusRunning TaskStatus = "running" // running
TaskStatusCompleted TaskStatus = "completed" // completed
TaskStatusFailed TaskStatus = "failed" // failed
TaskStatusCancelled TaskStatus = "cancelled" // cancelled
)

View File

@@ -1,14 +0,0 @@
package eino
// TaskType is the kind of work a task performs, expressed as a string code.
type TaskType string
const (
TaskTypeDocumentIngestion TaskType = "document_ingestion" // document ingestion task
TaskTypeVectorIngestion TaskType = "vector_ingestion" // vector ingestion task
TaskTypeIndexCreation TaskType = "index_creation" // index creation task
TaskTypeQAProcessing TaskType = "qa_processing" // question-answering processing task
TaskTypeKnowledgeConstruction TaskType = "knowledge_construction" // knowledge base construction task
TaskTypeGraphBuilding TaskType = "graph_building" // graph building task
TaskTypeKnowledgeSync TaskType = "knowledge_sync" // knowledge synchronisation task
)

View File

@@ -1,114 +0,0 @@
package gse
import (
"context"
"sort"
"github.com/go-ego/gse"
"github.com/go-ego/gse/hmm/extracker"
"github.com/go-ego/gse/hmm/segment"
"github.com/gogf/gf/v2/os/glog"
)
// GseTool is the package-wide keyword extraction tool set up by init. It is
// nil when initialisation failed.
var GseTool *gseTool

// init builds GseTool once when the package is loaded and logs the error if
// construction fails.
func init() {
	tool, err := newGseTool()
	if err != nil {
		glog.Error(context.Background(), err)
	}
	GseTool = tool
}
// gseTool bundles a gse segmenter with the TF-IDF and TextRank keyword
// extractors built on top of it.
type gseTool struct {
// seg is the word segmenter used by Cut.
seg gse.Segmenter
// tfidf is the TF-IDF tag extractor used by extractTFIDF.
tfidf *extracker.TagExtracter
// tr is the TextRank extractor used by extractTextRank.
tr *extracker.TextRanker
}
// newGseTool builds a gseTool from gse's embedded dictionary and stop-word
// list, attaches the segmenter to a TF-IDF extractor (loading its IDF data)
// and to a TextRank extractor.
//
// It returns a nil tool together with the error of the first step that fails.
func newGseTool() (tool *gseTool, err error) {
	// Segmenter backed by the embedded dictionary and stop words.
	var segmenter gse.Segmenter
	if err = segmenter.LoadDictEmbed(); err != nil {
		return nil, err
	}
	if err = segmenter.LoadStopEmbed(); err != nil {
		return nil, err
	}

	// TF-IDF extractor.
	tagExtractor := new(extracker.TagExtracter)
	tagExtractor.WithGse(segmenter)
	if err = tagExtractor.LoadIdf(); err != nil {
		return nil, err
	}

	// TextRank extractor.
	ranker := new(extracker.TextRanker)
	ranker.WithGse(segmenter)

	return &gseTool{seg: segmenter, tfidf: tagExtractor, tr: ranker}, nil
}
// Cut segments text into words by calling the segmenter's Cut with its second
// argument set to true (described by the original author as precise mode with
// HMM enabled).
func (k *gseTool) Cut(text string) []string {
return k.seg.Cut(text, true)
}
// Keyword is one extracted keyword together with its weight, as returned by
// Extract.
type Keyword struct {
// Word is the keyword text.
Word string `json:"word"`
// Score is the weight assigned by the extractor.
Score float64 `json:"score"`
}
// Extract returns the keywords of text found by both TF-IDF and TextRank,
// each asked for topN entries, sorted by score from high to low.
//
// The two result sets are merged by word; when a word appears in both, the
// TextRank weight replaces the TF-IDF weight. Because the sets are merged, the
// result can hold more than topN keywords. Words with equal scores are ordered
// alphabetically so that the output is deterministic.
func (k *gseTool) Extract(text string, topN int) []Keyword {
	// 1. TF-IDF keywords.
	tfTags := k.extractTFIDF(text, topN)
	// 2. TextRank keywords.
	trTags := k.extractTextRank(text, topN)

	// 3. Merge by word; TextRank is written last and therefore wins.
	scoreMap := make(map[string]float64, len(tfTags)+len(trTags))
	for _, tag := range tfTags {
		scoreMap[tag.Text] = tag.Weight
	}
	for _, tag := range trTags {
		scoreMap[tag.Text] = tag.Weight
	}

	res := make([]Keyword, 0, len(scoreMap))
	for word, score := range scoreMap {
		res = append(res, Keyword{Word: word, Score: score})
	}
	// Map iteration order is random and sort.Slice is not stable, so break
	// score ties on the word itself.
	sort.Slice(res, func(i, j int) bool {
		if res[i].Score != res[j].Score {
			return res[i].Score > res[j].Score
		}
		return res[i].Word < res[j].Word
	})
	return res
}
// extractTFIDF returns up to topN weighted keywords of text from the TF-IDF
// extractor.
func (k *gseTool) extractTFIDF(text string, topN int) segment.Segments {
return k.tfidf.ExtractTags(text, topN)
}
// extractTextRank returns up to topN weighted keywords of text from the
// TextRank extractor.
func (k *gseTool) extractTextRank(text string, topN int) segment.Segments {
return k.tr.TextRank(text, topN)
}

View File

@@ -1,57 +1,15 @@
server:
address: :3006
name: rag
workerId: 1
# Database.
database:
default:
- type: "pgsql"
host: "116.204.74.41"
port: "15432"
user: "postgres"
pass: "Bjang09@686^*^"
name: "rag"
prefix: "rag_knowledge_" # (可选)表名前缀
role: "master" # (可选)数据库主从角色(master/slave)默认为master。如果不使用应用主从机制请不配置或留空即可。
debug: true # (可选)开启调试模式
dryRun: false # (可选)ORM空跑(只读不写)
charset: "utf8" # (可选)数据库编码(如: utf8mb4/utf8/gbk/gb2312)一般设置为utf8mb4。默认为utf8。
timezone: "Asia/Shanghai" # (可选)时区配置,例如:Local
maxIdle: 5 # (可选)连接池最大闲置的连接数(默认10)
maxOpen: 20 # (可选)连接池最大打开的连接数(默认无限制)
maxLifetime: "30s" # (可选)连接对象可重复使用的时间长度(默认30秒)
maxIdleConnTime: "30s" # (可选v2.10新增)连接池中空闲连接的最大生存时间(默认30秒)。可以通过配置文件或SetConnMaxIdleTime方法设置避免长时间空闲连接占用资源。
createdAt: "created_at" # (可选)自动创建时间字段名称
updatedAt: "updated_at" # (可选)自动更新时间字段名称
deletedAt: "deleted_at" # (可选)软删除时间字段名称
timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性为true时CreatedAt/UpdatedAt/DeletedAt都将失效
- type: "pgsql"
host: "116.204.74.41"
port: "15432"
user: "postgres"
pass: "Bjang09@686^*^"
name: "tenant-1"
prefix: "rag_knowledge_" # (可选)表名前缀
role: "slave" # (可选)数据库主从角色(master/slave)默认为master。如果不使用应用主从机制请不配置或留空即可。
debug: false # (可选)开启调试模式
dryRun: false # (可选)ORM空跑(只读不写)
charset: "utf8" # (可选)数据库编码(如: utf8mb4/utf8/gbk/gb2312)一般设置为utf8mb4。默认为utf8。
timezone: "Asia/Shanghai" # (可选)时区配置,例如:Local
maxIdle: 5 # (可选)连接池最大闲置的连接数(默认10)
maxOpen: 20 # (可选)连接池最大打开的连接数(默认无限制)
maxLifetime: "30s" # (可选)连接对象可重复使用的时间长度(默认30秒)
maxIdleConnTime: "30s" # (可选v2.10新增)连接池中空闲连接的最大生存时间(默认30秒)。可以通过配置文件或SetConnMaxIdleTime方法设置避免长时间空闲连接占用资源。
createdAt: "created_at" # (可选)自动创建时间字段名称
updatedAt: "updated_at" # (可选)自动更新时间字段名称
deletedAt: "deleted_at" # (可选)软删除时间字段名称
timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性为true时CreatedAt/UpdatedAt/DeletedAt都将失效
rag_knowledge:
- type: "pgsql"
host: "localhost"
port: "5432"
host: "116.204.74.41"
port: "15432"
user: "postgres"
pass: "123456"
pass: "Bjang09@686^*^"
name: "tenant-1"
prefix: "rag_knowledge_" # (可选)表名前缀
role: "master"
@@ -69,10 +27,10 @@ database:
timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性为true时CreatedAt/UpdatedAt/DeletedAt都将失效
rag_vector:
- type: "pgsql"
host: "localhost"
port: "5432"
host: "116.204.74.41"
port: "15432"
user: "postgres"
pass: "123456"
pass: "Bjang09@686^*^"
name: "tenant-1"
prefix: "rag_vector_" # (可选)表名前缀
role: "master"
@@ -91,22 +49,22 @@ database:
redis:
default:
address: "localhost:6379"
address: "116.204.74.41:6379"
db: 0
consul:
address: localhost:8500
address: 116.204.74.41:8500
jaeger:
addr: localhost:4318
addr: 116.204.74.41:4318
# eino框架配置
eino:
# 文件切分配置
splitter:
bufferSize: 1
minChunkSize: 64
percentile: 0.75
bufferSize: 3 # 必须 >=3 才能识别上下文语义
minChunkSize: 1 # 避免切碎
percentile: 0.75 # 保持不变
# 向量化配置
embedding:
provider: "dashscope"
@@ -115,23 +73,18 @@ eino:
# apiType: "multi_modal_api"
apiKey: "sk-4a8b82770bf74bc490eb3e4c5a8e2be9"
model: "text-embedding-v3"
chatmodel:
provider: "dashscope"
apiKey: "sk-4a8b82770bf74bc490eb3e4c5a8e2be9"
model: "qwen-turbo"
rerank:
provider: "dashscope"
apiKey: "sk-4a8b82770bf74bc490eb3e4c5a8e2be9"
model: "qwen3-rerank"
# 文件上传服务地址与oss模块minio中的endpoint一致
filePrefix: "http://116.204.74.41:9000"
gmq:
redis:
primary:
addr: "localhost"
port: "6379"
db: 0
username: ""
password: ""
poolSize: 10
minIdleConn: 5
maxActiveConn: 10
maxRetries: 30
# Meilisearch 全文检索配置
meilisearch:
default:

26
consts/keyword/type.go Normal file
View File

@@ -0,0 +1,26 @@
package keyword
import "github.com/gogf/gf/v2/util/gconv"
// Known keyword types. The description strings are the Chinese labels
// "自定义" (user-defined) and "初始化" (initial).
var (
KeywordTypeDefined = newKeywordType(gconv.PtrInt8(1), "自定义")
KeywordTypeInitial = newKeywordType(gconv.PtrInt8(2), "初始化")
)
// KeywordType is the code of a keyword type, held as a pointer to int8.
type KeywordType *int8
// keywordType pairs a keyword type code with its description.
type keywordType struct {
code KeywordType
desc string
}
// Code returns the keyword type code.
func (s keywordType) Code() KeywordType {
return s.code
}
// Desc returns the description of the keyword type.
func (s keywordType) Desc() string {
return s.desc
}
// newKeywordType builds a keywordType from a code and its description.
func newKeywordType(code KeywordType, desc string) keywordType {
return keywordType{code: code, desc: desc}
}

132
consts/model/config_type.go Normal file
View File

@@ -0,0 +1,132 @@
package model
import (
"github.com/gogf/gf/v2/util/gconv"
)
// Provider configuration types. The first group lists the providers available
// for vector (embedding) models, the second those available for chat models.
// Several providers appear in both groups with the same code.
var (
	ModelConfigTypeVectorArk          = newModelConfigType(gconv.PtrString("ark"), "字节跳动火山引擎方舟大模型服务")
	ModelConfigTypeVectorOllama       = newModelConfigType(gconv.PtrString("ollama"), "Ollama 本地大模型运行工具")
	ModelConfigTypeVectorOpenAI       = newModelConfigType(gconv.PtrString("openAI"), "OpenAI 官方大模型服务")
	ModelConfigTypeVectorQianfan      = newModelConfigType(gconv.PtrString("qianfan"), "百度文心一言千帆大模型平台")
	ModelConfigTypeVectorTencentCloud = newModelConfigType(gconv.PtrString("tencentCloud"), "腾讯云大模型服务")
	ModelConfigTypeVectorDashScope    = newModelConfigType(gconv.PtrString("dashScope"), "阿里云通义千问 DashScope 平台")

	ModelConfigTypeChatArk      = newModelConfigType(gconv.PtrString("ark"), "字节跳动火山引擎方舟大模型服务")
	ModelConfigTypeChatArkBot   = newModelConfigType(gconv.PtrString("arkBot"), "火山引擎 ARK Bot 智能体服务")
	ModelConfigTypeChatClaude   = newModelConfigType(gconv.PtrString("claude"), "Anthropic Claude 系列大模型")
	ModelConfigTypeChatDeepSeek = newModelConfigType(gconv.PtrString("deepSeek"), "DeepSeek 深度求索大模型")
	ModelConfigTypeChatOllama   = newModelConfigType(gconv.PtrString("ollama"), "Ollama 本地大模型运行工具")
	ModelConfigTypeChatOpenAI   = newModelConfigType(gconv.PtrString("openAI"), "OpenAI 官方大模型服务")
	ModelConfigTypeChatQianfan  = newModelConfigType(gconv.PtrString("qianfan"), "百度文心一言千帆大模型平台")
	// Qwen (Tongyi Qianwen) is Alibaba Cloud's model family; the previous
	// description attributed it to Tencent.
	ModelConfigTypeChatQwen = newModelConfigType(gconv.PtrString("qwen"), "阿里云通义千问大模型平台")
)
// ModelConfigType is the code of a model provider configuration, held as a
// pointer to string.
type ModelConfigType *string
// modelConfigType pairs a provider configuration code with its description.
type modelConfigType struct {
code ModelConfigType
desc string
}
// Code returns the provider configuration code.
func (s modelConfigType) Code() ModelConfigType {
return s.code
}
// Desc returns the description of the provider configuration.
func (s modelConfigType) Desc() string {
return s.desc
}
// newModelConfigType builds a modelConfigType from a code and its description.
func newModelConfigType(code ModelConfigType, desc string) modelConfigType {
return modelConfigType{code: code, desc: desc}
}
// GetVectorDescByCode returns the description of the vector-model provider
// identified by code, or "未知类型" (unknown type) when code is nil or matches
// no known provider.
func GetVectorDescByCode(code ModelConfigType) string {
	// ModelConfigType is a *string; a nil code is reported as unknown instead
	// of panicking on the dereference below.
	if code == nil {
		return "未知类型"
	}
	switch *code {
	case *ModelConfigTypeVectorArk.Code():
		return ModelConfigTypeVectorArk.Desc()
	case *ModelConfigTypeVectorOllama.Code():
		return ModelConfigTypeVectorOllama.Desc()
	case *ModelConfigTypeVectorOpenAI.Code():
		return ModelConfigTypeVectorOpenAI.Desc()
	case *ModelConfigTypeVectorQianfan.Code():
		return ModelConfigTypeVectorQianfan.Desc()
	case *ModelConfigTypeVectorTencentCloud.Code():
		return ModelConfigTypeVectorTencentCloud.Desc()
	case *ModelConfigTypeVectorDashScope.Code():
		return ModelConfigTypeVectorDashScope.Desc()
	}
	return "未知类型"
}
// GetChatDescByCode returns the description of the chat-model provider
// identified by code, or "未知类型" (unknown type) when code is nil or matches
// no known provider.
func GetChatDescByCode(code ModelConfigType) string {
	// ModelConfigType is a *string; a nil code is reported as unknown instead
	// of panicking on the dereference below.
	if code == nil {
		return "未知类型"
	}
	switch *code {
	case *ModelConfigTypeChatArk.Code():
		return ModelConfigTypeChatArk.Desc()
	case *ModelConfigTypeChatArkBot.Code():
		return ModelConfigTypeChatArkBot.Desc()
	case *ModelConfigTypeChatClaude.Code():
		return ModelConfigTypeChatClaude.Desc()
	case *ModelConfigTypeChatDeepSeek.Code():
		return ModelConfigTypeChatDeepSeek.Desc()
	case *ModelConfigTypeChatOllama.Code():
		return ModelConfigTypeChatOllama.Desc()
	case *ModelConfigTypeChatOpenAI.Code():
		return ModelConfigTypeChatOpenAI.Desc()
	case *ModelConfigTypeChatQianfan.Code():
		return ModelConfigTypeChatQianfan.Desc()
	case *ModelConfigTypeChatQwen.Code():
		return ModelConfigTypeChatQwen.Desc()
	}
	return "未知类型"
}
// GetModelConfigTypeEnumRes is the list of provider configuration options
// returned by GetAllModelConfigTypeEnums.
type GetModelConfigTypeEnumRes struct {
Options []ConfigTypeKeyValue `json:"options"`
}
// ConfigTypeKeyValue is one option: a provider code and its description.
type ConfigTypeKeyValue struct {
Key interface{} `json:"key"` // the provider configuration code
Value interface{} `json:"value"` // the description of that code
}
// GetAllModelConfigTypeEnums returns the provider configuration options
// available for the given model type: the vector providers for the vector
// type, the chat providers for the chat type.
//
// A nil or unrecognised modelType yields a result with an empty, non-nil
// Options slice.
func GetAllModelConfigTypeEnums(modelType ModelType) *GetModelConfigTypeEnumRes {
	var list []modelConfigType
	// ModelType is a *string; skip the lookup for nil instead of panicking.
	if modelType != nil {
		switch *modelType {
		case *ModelTypeVector.Code():
			list = []modelConfigType{
				ModelConfigTypeVectorArk,
				ModelConfigTypeVectorOllama,
				ModelConfigTypeVectorOpenAI,
				ModelConfigTypeVectorQianfan,
				ModelConfigTypeVectorTencentCloud,
				ModelConfigTypeVectorDashScope,
			}
		case *ModelTypeChat.Code():
			list = []modelConfigType{
				ModelConfigTypeChatArk,
				ModelConfigTypeChatArkBot,
				ModelConfigTypeChatClaude,
				ModelConfigTypeChatDeepSeek,
				ModelConfigTypeChatOllama,
				ModelConfigTypeChatOpenAI,
				ModelConfigTypeChatQianfan,
				ModelConfigTypeChatQwen,
			}
		}
	}

	// Assemble the response.
	options := make([]ConfigTypeKeyValue, 0, len(list))
	for _, item := range list {
		options = append(options, ConfigTypeKeyValue{
			Key:   *item.Code(),
			Value: item.Desc(),
		})
	}

	return &GetModelConfigTypeEnumRes{
		Options: options,
	}
}

66
consts/model/type.go Normal file
View File

@@ -0,0 +1,66 @@
package model
import "github.com/gogf/gf/v2/util/gconv"
// Known model types. The description strings are the Chinese labels
// "向量" (vector) and "对话" (chat).
var (
ModelTypeVector = newModelType(gconv.PtrString("vector"), "向量")
ModelTypeChat = newModelType(gconv.PtrString("chat"), "对话")
)
// ModelType is the code of a model type, held as a pointer to string.
type ModelType *string
// modelType pairs a model type code with its description.
type modelType struct {
code ModelType
desc string
}
// Code returns the model type code.
func (s modelType) Code() ModelType {
return s.code
}
// Desc returns the description of the model type.
func (s modelType) Desc() string {
return s.desc
}
// newModelType builds a modelType from a code and its description.
func newModelType(code ModelType, desc string) modelType {
return modelType{code: code, desc: desc}
}
// GetModelTypeDescByCode returns the description of the model type identified
// by code, or "未知类型" (unknown type) when code is nil or matches no known
// model type.
func GetModelTypeDescByCode(code ModelType) string {
	// ModelType is a *string; a nil code is reported as unknown instead of
	// panicking on the dereference below.
	if code == nil {
		return "未知类型"
	}
	switch *code {
	case *ModelTypeVector.Code():
		return ModelTypeVector.Desc()
	case *ModelTypeChat.Code():
		return ModelTypeChat.Desc()
	}
	return "未知类型"
}
// GetModelTypeEnumRes is the list of model type options returned by
// GetAllModelTypeEnums.
type GetModelTypeEnumRes struct {
Options []TypeKeyValue `json:"options"`
}
// TypeKeyValue is one option: a model type code and its description.
type TypeKeyValue struct {
Key interface{} `json:"key"` // the model type code
Value interface{} `json:"value"` // the description of that code
}
// GetAllModelTypeEnums returns the selectable model types as key/value
// options. Only the chat type is listed; the vector type is commented out.
func GetAllModelTypeEnums() *GetModelTypeEnumRes {
	// Model types exposed to callers.
	enums := []modelType{
		//ModelTypeVector,
		ModelTypeChat,
	}

	// Convert every entry to its key/value form.
	res := &GetModelTypeEnumRes{Options: make([]TypeKeyValue, len(enums))}
	for i := range enums {
		res.Options[i] = TypeKeyValue{
			Key:   *enums[i].Code(),
			Value: enums[i].Desc(),
		}
	}
	return res
}

15
consts/public/msg_key.go Normal file
View File

@@ -0,0 +1,15 @@
package public
// GmqMsgPluginsName is the name of the gmq message plugin.
const GmqMsgPluginsName = "gmq_msg"
// Key templates for knowledge locks and content hashes; each contains one %v
// placeholder.
const KnowledgeLockEsKey = "knowledge:lock:knowledgeIdEs-%v"
const KnowledgeLockSqlKey = "knowledge:lock:knowledgeIdSql-%v"
const KnowledgeContentHashEsKey = "knowledge:knowledgeId:contentHashEs-%v"
const KnowledgeContentHashSqlKey = "knowledge:knowledgeId:contentHashSql-%v"
// Settings of the document vector message stream.
const (
KnowledgeDocumentVectorTopic = "knowledge:document:vector:stream" // stream key; must match the key used when publishing
KnowledgeDocumentVectorConsumer = "knowledge-document-vector-consumer" // consumer name (unique identifier)
KnowledgeDocumentVectorCount = 1 // batch size: one message per read
KnowledgeDocumentVectorAutoAck = false // whether messages are acknowledged automatically (true) or not (false)
)

View File

@@ -1,20 +0,0 @@
package public
// Key templates for knowledge locks and content hashes; each contains one %v
// placeholder.
const KnowledgeLockEsKey = "rag:knowledge:lock:knowledgeIdEs-%v"
const KnowledgeLockSqlKey = "rag:knowledge:lock:knowledgeIdSql-%v"
const KnowledgeContentHashEsKey = "rag:knowledge:knowledgeId:contentHashEs-%v"
const KnowledgeContentHashSqlKey = "rag:knowledge:knowledgeId:contentHashSql-%v"
// Settings of the document vector status message stream.
const (
KnowledgeDocumentVectorStatusTopic = "knowledge:document:vector:status:stream"
KnowledgeDocumentVectorStatusConsumer = "knowledge-document-vector-status-consumer"
KnowledgeDocumentVectorStatusBatchSize = 1
KnowledgeDocumentVectorStatusAutoAck = false
)
// Settings of the document chunk message stream.
const (
KnowledgeDocumentChunkTopic = "knowledge:document:chunk:stream" // stream key; must match the key used when publishing
KnowledgeDocumentChunkConsumer = "knowledge-document-chunk-consumer" // consumer name (unique identifier)
KnowledgeDocumentChunkBatchSize = 1 // batch size: one message per read
KnowledgeDocumentChunkAutoAck = false // whether messages are acknowledged automatically (true) or not (false)
)

View File

@@ -1,5 +1,6 @@
package public
// 数据库名称
const (
DbNameKnowledge = "rag_knowledge"
DbNameVector = "rag_vector"
@@ -7,11 +8,13 @@ const (
// sql 数据库表名
const (
TableNameDocument = "document"
TableNameDataset = "dataset"
TableNameKeyword = "keyword"
TableNameDatasetIndex = "dataset_index"
TableNameDocumentChunk = "document_chunk"
TableNameDocument = "document"
TableNameDataset = "dataset"
TableNameKeyword = "keyword"
TableNameTask = "task"
TableNameModel = "model"
TableNameDatasetIndex = "dataset_index"
TableNameDocumentVector = "document_vector"
)
// es 索引名称

30
consts/task/consts.go Normal file
View File

@@ -0,0 +1,30 @@
package task
// TaskType enumerates the kinds of task: the document-parsing sub-tasks plus
// the top-level parsing task itself.
type TaskType string

const (
	TaskTypeExtractKeywords = TaskType("EXTRACT_KEYWORDS") // keyword extraction
	TaskTypeGenerateVector  = TaskType("GENERATE_VECTOR")  // vector generation
	TaskTypeFullTextSearch  = TaskType("FULL_TEXT_SEARCH") // full-text search
	TaskTypeDocParse        = TaskType("DOC_PARSE")        // top-level document parsing task
)
// TaskStatus enumerates the lifecycle states of a task.
type TaskStatus string

const (
	TaskStatusPending   = TaskStatus("PENDING")   // waiting to run
	TaskStatusRunning   = TaskStatus("RUNNING")   // in progress
	TaskStatusCompleted = TaskStatus("COMPLETED") // finished
	TaskStatusFailed    = TaskStatus("FAILED")    // execution failed
)
// TaskPriority is the priority level of a task; larger values mean higher priority labels.
type TaskPriority int

const (
	TaskPriorityLow    TaskPriority = iota + 1 // low (1)
	TaskPriorityMedium                         // medium (2)
	TaskPriorityHigh                           // high (3)
)

View File

@@ -6,7 +6,7 @@ import (
"rag/model/dto"
"rag/service"
"gitea.com/red-future/common/beans"
"gitea.redpowerfuture.com/red-future/common/beans"
"github.com/gogf/gf/v2/frame/g"
)
@@ -40,9 +40,3 @@ func (c *dataset) List(ctx context.Context, req *dto.ListDatasetReq) (res *dto.L
res, err = service.Dataset.List(ctx, req)
return
}
// Search 搜索
//func (c *dataset) Search(ctx context.Context, req *dto.SearchReq) (res *dto.SearchRes, err error) {
// res, err = service.Dataset.Search(ctx, req)
// return
//}

View File

@@ -2,11 +2,10 @@ package controller
import (
"context"
"rag/model/dto"
"rag/service"
"gitea.com/red-future/common/beans"
"gitea.redpowerfuture.com/red-future/common/beans"
"github.com/gogf/gf/v2/frame/g"
)
@@ -33,7 +32,7 @@ func (c *document) Delete(ctx context.Context, req *dto.DeleteDocumentReq) (res
}
// Get 获取文件详情
func (c *document) Get(ctx context.Context, req *dto.GetDocumentReq) (res *dto.DocumentVO, err error) {
func (c *document) Get(ctx context.Context, req *dto.GetDocumentReq) (res *dto.GetDocumentRes, err error) {
res, err = service.Document.Get(ctx, req)
return
}
@@ -47,8 +46,23 @@ func (c *document) List(ctx context.Context, req *dto.ListDocumentReq) (res *dto
return
}
// Process 处理文件(向量化)
func (c *document) Process(ctx context.Context, req *dto.ProcessDocumentReq) (res *dto.ProcessDocumentRes, err error) {
res, err = service.Document.Process(ctx, req)
// DocumentVector 处理文件(向量化)
func (c *document) DocumentVector(ctx context.Context, req *dto.DocumentVectorReq) (res *beans.ResponseEmpty, err error) {
err = service.Document.Vector(ctx, req)
return
}
// VectorSemanticSplit forwards the request to service.Document.VectorSemanticSplit
// and returns only its error; the response body is always nil.
func (c *document) VectorSemanticSplit(ctx context.Context, req *dto.VectorSemanticSplitReq) (res *beans.ResponseEmpty, err error) {
	return nil, service.Document.VectorSemanticSplit(ctx, req)
}
// SearchRecursiveSplit forwards the request to service.Document.SearchRecursiveSplit
// and returns only its error; the response body is always nil.
func (c *document) SearchRecursiveSplit(ctx context.Context, req *dto.SearchRecursiveSplitReq) (res *beans.ResponseEmpty, err error) {
	return nil, service.Document.SearchRecursiveSplit(ctx, req)
}
// KeywordExtract forwards the request to service.Document.KeywordExtract
// and returns only its error; the response body is always nil.
func (c *document) KeywordExtract(ctx context.Context, req *dto.KeywordExtractReq) (res *beans.ResponseEmpty, err error) {
	return nil, service.Document.KeywordExtract(ctx, req)
}

View File

@@ -1,29 +0,0 @@
package controller
import (
"context"
"rag/model/dto"
"rag/service"
"gitea.com/red-future/common/beans"
"github.com/gogf/gf/v2/frame/g"
)
type documentChunk struct{}
var DocumentChunk = new(documentChunk)
// Update 更新文件片段
func (c *documentChunk) Update(ctx context.Context, req *dto.UpdateDocumentChunkReq) (res *beans.ResponseEmpty, err error) {
err = service.DocumentChunk.Update(ctx, req)
return
}
// List 文件片段列表
func (c *documentChunk) List(ctx context.Context, req *dto.ListDocumentChunkReq) (res *dto.ListDocumentChunkRes, err error) {
if !g.IsEmpty(req.Page) {
req.Page = &beans.Page{PageNum: 1, PageSize: 20}
}
res, err = service.DocumentChunk.List(ctx, req)
return
}

View File

@@ -0,0 +1,35 @@
package controller
import (
"context"
"rag/model/dto"
"rag/service"
"gitea.redpowerfuture.com/red-future/common/beans"
"github.com/gogf/gf/v2/frame/g"
)
// documentVector is the stateless controller for document vector endpoints.
type documentVector struct{}

// DocumentVector is the shared controller instance.
var DocumentVector = &documentVector{}
// Query runs a RAG query by delegating to service.DocumentVector.Query.
func (c *documentVector) Query(ctx context.Context, req *dto.RAGQueryReq) (res *dto.RAGQueryRes, err error) {
	return service.DocumentVector.Query(ctx, req)
}
// Update modifies a document chunk by delegating to service.DocumentVector.Update;
// the response body is always nil.
func (c *documentVector) Update(ctx context.Context, req *dto.UpdateDocumentVectorReq) (res *beans.ResponseEmpty, err error) {
	return nil, service.DocumentVector.Update(ctx, req)
}
// List returns document chunks by delegating to service.DocumentVector.List.
//
// When the request carries no pagination (req.Page is empty per g.IsEmpty),
// it defaults to page 1 with 20 items per page.
func (c *documentVector) List(ctx context.Context, req *dto.ListDocumentVectorReq) (res *dto.ListDocumentVectorRes, err error) {
	// Apply the default only when the caller sent no paging. The previous
	// check was inverted: it overwrote caller-supplied paging and left a
	// missing page unset.
	if g.IsEmpty(req.Page) {
		req.Page = &beans.Page{PageNum: 1, PageSize: 20}
	}
	return service.DocumentVector.List(ctx, req)
}

View File

@@ -2,11 +2,10 @@ package controller
import (
"context"
"rag/model/dto"
"rag/service"
"gitea.com/red-future/common/beans"
"gitea.redpowerfuture.com/red-future/common/beans"
"github.com/gogf/gf/v2/frame/g"
)

52
controller/model.go Normal file
View File

@@ -0,0 +1,52 @@
package controller
import (
"context"
"rag/model/dto"
"rag/service"
"gitea.redpowerfuture.com/red-future/common/beans"
"github.com/gogf/gf/v2/frame/g"
)
// model is the stateless controller for model endpoints.
type model struct{}

// Model is the shared controller instance.
var Model = &model{}
// GetModelTypeEnums returns the model enums by delegating to
// service.ModelService.GetModelAllEnums.
func (c *model) GetModelTypeEnums(ctx context.Context, req *dto.GetModelAllEnumsReq) (res *dto.GetModelEnumRes, err error) {
	return service.ModelService.GetModelAllEnums(ctx, req)
}
// GetModelConfigFormFields delegates to service.ModelService.GetModelConfigFormFields.
func (c *model) GetModelConfigFormFields(ctx context.Context, req *dto.GetModelConfigFormFieldsReq) (res *dto.GetModelConfigFormFieldsRes, err error) {
	return service.ModelService.GetModelConfigFormFields(ctx, req)
}
// Create adds a model by delegating to service.ModelService.Create.
func (c *model) Create(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
	return service.ModelService.Create(ctx, req)
}
// Update modifies a model by delegating to service.ModelService.Update;
// the response body is always nil.
func (c *model) Update(ctx context.Context, req *dto.UpdateModelReq) (res *beans.ResponseEmpty, err error) {
	return nil, service.ModelService.Update(ctx, req)
}
// Delete removes a model by delegating to service.ModelService.Delete;
// the response body is always nil.
func (c *model) Delete(ctx context.Context, req *dto.DeleteModelReq) (res *beans.ResponseEmpty, err error) {
	return nil, service.ModelService.Delete(ctx, req)
}
// Get fetches a single model by delegating to service.ModelService.Get.
func (c *model) Get(ctx context.Context, req *dto.GetModelReq) (res *dto.ModelVO, err error) {
	return service.ModelService.Get(ctx, req)
}
// List returns models by delegating to service.ModelService.List.
//
// When the request carries no pagination (req.Page is empty per g.IsEmpty),
// it defaults to page 1 with 20 items per page.
func (c *model) List(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) {
	// Apply the default only when the caller sent no paging. The previous
	// check was inverted: it overwrote caller-supplied paging and left a
	// missing page unset.
	if g.IsEmpty(req.Page) {
		req.Page = &beans.Page{PageNum: 1, PageSize: 20}
	}
	return service.ModelService.List(ctx, req)
}

16
controller/task.go Normal file
View File

@@ -0,0 +1,16 @@
package controller
import (
"context"
"rag/model/dto"
"rag/service"
)
// task is the stateless controller for task endpoints.
type task struct{}

// Task is the shared controller instance.
var Task = &task{}
// Get looks up tasks by delegating to service.Task.Get.
func (c *task) Get(ctx context.Context, req *dto.GetTaskReq) (res *dto.ListTaskRes, err error) {
	return service.Task.Get(ctx, req)
}

View File

@@ -6,7 +6,7 @@ import (
"rag/model/dto"
"rag/model/entity"
"gitea.com/red-future/common/db/gfdb"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"

View File

@@ -7,7 +7,7 @@ import (
"rag/consts/public"
"rag/model/entity"
"gitea.com/red-future/common/db/gfdb"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
)
var DatasetIndex = new(datasetIndexDao)
@@ -49,8 +49,9 @@ func (d *datasetIndexDao) InsertIndex(ctx context.Context, indexName string) (er
CREATE INDEX IF NOT EXISTS %s
ON %s
USING ivfflat (vector vector_cosine_ops)
WITH (lists = 100)
WHERE vector IS NOT NULL;
`, indexName, gfdb.TablePrefix+public.TableNameDocumentChunk)
`, indexName, gfdb.TablePrefix+public.TableNameDocumentVector)
_, err = db.Exec(ctx, sqlStr)
return
}

View File

@@ -6,7 +6,7 @@ import (
"rag/model/dto"
"rag/model/entity"
"gitea.com/red-future/common/db/gfdb"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
@@ -48,16 +48,28 @@ func (d *documentDao) Update(ctx context.Context, req *dto.UpdateDocumentReq) (r
// Delete 删除文件
func (d *documentDao) Delete(ctx context.Context, req *dto.DeleteDocumentReq) (rows int64, err error) {
r, err := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameDocument).Where(entity.DocumentCol.Id, req.Id).Delete()
r, err := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameDocument).OmitEmpty().
Where(entity.DocumentCol.Id, req.Id).
Delete()
if err != nil {
return
}
return r.RowsAffected()
}
// GetByID 根据ID获取文件
func (d *documentDao) GetByID(ctx context.Context, req *dto.GetDocumentReq, fields ...string) (res *entity.Document, err error) {
r, err := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameDocument).Where(entity.DocumentCol.Id, req.Id).Fields(fields).One()
func (d *documentDao) Count(ctx context.Context, req *dto.ListDocumentReq) (count int, err error) {
count, err = gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameDocument).OmitEmpty().
Where(entity.DocumentCol.DatasetId, req.DatasetId).
Where(entity.DocumentCol.Title, req.Title).Count()
return
}
// Get 根据ID获取文件
func (d *documentDao) Get(ctx context.Context, req *dto.GetDocumentReq, fields ...string) (res *entity.Document, err error) {
r, err := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameDocument).Fields(fields).OmitEmpty().
Where(entity.DocumentCol.Id, req.Id).
Where(entity.DocumentCol.Title, req.Title).
Where(entity.DocumentCol.DatasetId, req.DatasetId).One()
if err != nil {
return
}

View File

@@ -1,57 +0,0 @@
package dao
import (
"context"
"rag/consts/public"
"rag/model/dto"
"rag/model/entity"
"gitea.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/util/gconv"
)
var DocumentChunk = new(documentChunkDao)
type documentChunkDao struct{}
// BatchInsert 批量插入文件块
func (d *documentChunkDao) BatchInsert(ctx context.Context, req []*dto.VectorDocumentChunkMsg) (rows int64, err error) {
var res []*entity.DocumentChunk
if err = gconv.Structs(req, &res); err != nil {
return
}
r, err := gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDocumentChunk).Data(&res).Insert()
if err != nil {
return
}
return r.RowsAffected()
}
// Update 更新文件块
func (d *documentChunkDao) Update(ctx context.Context, req *dto.UpdateDocumentChunkReq) (rows int64, err error) {
model := gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDocumentChunk)
r, err := model.Data(&req).Where(entity.DocumentChunkCol.Id, req.Id).Update()
if err != nil {
return
}
return r.RowsAffected()
}
// List 文件块列表
func (d *documentChunkDao) List(ctx context.Context, req *dto.ListDocumentChunkReq, fields ...string) (res []*entity.DocumentChunk, total int, err error) {
model := gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDocumentChunk).Fields(fields).OmitEmpty().
Where(entity.DocumentChunkCol.DatasetId, req.DatasetId).
Where(entity.DocumentChunkCol.DocumentId, req.DocumentId).
Where(entity.DocumentChunkCol.Status, req.Status).
Where(entity.DocumentChunkCol.VectorStatus, req.VectorStatus).
OrderDesc(entity.DocumentChunkCol.CreatedAt)
if req.Page != nil {
model.Page(int(req.Page.PageNum), int(req.Page.PageSize))
}
r, total, err := model.AllAndCount(false)
if err != nil {
return
}
err = r.Structs(&res)
return
}

152
dao/document_vector.go Normal file
View File

@@ -0,0 +1,152 @@
package dao
import (
"context"
"fmt"
"rag/consts/public"
"rag/model/dto"
"rag/model/entity"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"gitea.redpowerfuture.com/red-future/common/full-text-search/meilisearch"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
"github.com/pgvector/pgvector-go"
)
// documentVectorDao groups the data-access methods for document vector rows.
type documentVectorDao struct{}

// DocumentVector is the shared DAO instance.
var DocumentVector = &documentVectorDao{}
// BatchInsert converts the messages into entity.DocumentVector rows and
// inserts them in one statement, returning the number of affected rows.
func (d *documentVectorDao) BatchInsert(ctx context.Context, req []*dto.VectorDocumentVectorMsg) (rows int64, err error) {
	var entities []*entity.DocumentVector
	if err = gconv.Structs(req, &entities); err != nil {
		return 0, err
	}
	result, err := gfdb.DB(ctx, public.DbNameVector).
		Model(ctx, public.TableNameDocumentVector).
		Data(&entities).
		Insert()
	if err != nil {
		return 0, err
	}
	return result.RowsAffected()
}
// Update writes req onto the row whose id equals req.Id and returns the
// number of affected rows.
func (d *documentVectorDao) Update(ctx context.Context, req *dto.UpdateDocumentVectorReq) (rows int64, err error) {
	result, err := gfdb.DB(ctx, public.DbNameVector).
		Model(ctx, public.TableNameDocumentVector).
		OmitEmpty().
		Data(&req).
		Where(entity.DocumentVectorCol.Id, req.Id).
		Update()
	if err != nil {
		return 0, err
	}
	return result.RowsAffected()
}
// Delete removes rows filtered by req.Id and req.DocumentId (chained after
// OmitEmpty) and returns the number of affected rows.
func (d *documentVectorDao) Delete(ctx context.Context, req *dto.DeleteDocumentVectorReq) (rows int64, err error) {
	query := gfdb.DB(ctx, public.DbNameVector).
		Model(ctx, public.TableNameDocumentVector).
		OmitEmpty().
		Where(entity.DocumentVectorCol.Id, req.Id).
		Where(entity.DocumentVectorCol.DocumentId, req.DocumentId)
	deleted, err := query.Delete()
	if err != nil {
		return 0, err
	}
	return deleted.RowsAffected()
}
// List returns the document vector rows matching req together with the total
// match count reported by AllAndCount.
//
// Filters on dataset ID, document ID, status, vector status and the
// req.DocumentIds set are chained after OmitEmpty(). A non-empty req.Keyword
// adds a `%keyword%` LIKE match on the content column. Rows are ordered by
// ascending chunk index, and paging is applied only when req.Page is non-nil.
// The optional fields restrict the selected columns.
func (d *documentVectorDao) List(ctx context.Context, req *dto.ListDocumentVectorReq, fields ...string) (res []*entity.DocumentVector, total int, err error) {
model := gfdb.DB(ctx, public.DbNameVector).Model(ctx, public.TableNameDocumentVector).Fields(fields).OmitEmpty().
Where(entity.DocumentVectorCol.DatasetId, req.DatasetId).
Where(entity.DocumentVectorCol.DocumentId, req.DocumentId).
Where(entity.DocumentVectorCol.Status, req.Status).
Where(entity.DocumentVectorCol.VectorStatus, req.VectorStatus).
WhereIn(entity.DocumentVectorCol.DocumentId, req.DocumentIds)
// NOTE(review): the results of WhereLike/OrderAsc/Page below are discarded,
// so this relies on the model being mutated in place — confirm against the
// gfdb wrapper's chaining semantics.
if !g.IsEmpty(req.Keyword) {
model.WhereLike(entity.DocumentVectorCol.Content, "%"+req.Keyword+"%")
}
model.OrderAsc(entity.DocumentVectorCol.ChunkIndex)
if req.Page != nil {
model.Page(int(req.Page.PageNum), int(req.Page.PageSize))
}
r, total, err := model.AllAndCount(false)
if err != nil {
return
}
// Map the raw result rows onto the entity slice.
err = r.Structs(&res)
return
}
// GetAllByVector returns up to topK rows ordered by ascending `<=>` distance
// between the stored vector and the given vector.
//
// Filtering is exclusive: a non-empty documentIds restricts the search to
// those documents; otherwise a non-empty datasetIds restricts it to those
// datasets; with both empty no ID filter is applied. Rows whose vector is NULL
// are always skipped.
func (d *documentVectorDao) GetAllByVector(ctx context.Context, datasetIds, documentIds []int64, vector pgvector.Vector, topK int) (list gdb.List, err error) {
	// Arguments are bound in placeholder order: vector, optional IN list, LIMIT.
	queryParams := []interface{}{vector}

	// Exactly one ID filter is emitted so the number of `IN (?)` placeholders
	// always matches the bound arguments. Previously both argument lists were
	// appended while only the last-written clause survived.
	var whereCondition string
	switch {
	case len(documentIds) > 0:
		whereCondition = fmt.Sprintf(" AND %s IN (?) ", entity.DocumentVectorCol.DocumentId)
		queryParams = append(queryParams, documentIds)
	case len(datasetIds) > 0:
		whereCondition = fmt.Sprintf(" AND %s IN (?) ", entity.DocumentVectorCol.DatasetId)
		queryParams = append(queryParams, datasetIds)
	}
	queryParams = append(queryParams, topK)

	sql := `
SELECT id, content, dataset_id, document_id,vector <=> ? AS distance
FROM rag_vector_document_vector
WHERE 1=1 ` + whereCondition + ` AND vector IS NOT NULL ORDER BY distance ASC LIMIT ?`

	result, err := gfdb.DB(ctx, public.DbNameVector).GetAll(ctx, sql, queryParams...)
	if err != nil {
		return nil, err
	}
	return result.List(), nil
}
// SearchByKeywords runs a full-text search for query against the document
// chunk index and returns the hits as a gdb.List.
//
// At most topK hits are requested, with ranking scores included. A non-empty
// datasetIds sets an `IN [...]` filter on the dataset ID; a non-empty
// documentIds then replaces it with a filter on the document ID.
func (d *documentVectorDao) SearchByKeywords(ctx context.Context, query string, datasetIds, documentIds []int64, topK int) (list gdb.List, err error) {
	params := &meilisearch.SearchParams{
		Query:            query,
		Limit:            int64(topK),
		ShowRankingScore: true,
	}

	if len(datasetIds) > 0 {
		joined := gstr.Implode(", ", gconv.Strings(datasetIds))
		params.Filter = fmt.Sprintf("%s IN [%s]", entity.DocumentVectorCol.DatasetId, joined)
	}
	// Written second, so the document filter wins when both are supplied.
	if len(documentIds) > 0 {
		joined := gstr.Implode(", ", gconv.Strings(documentIds))
		params.Filter = fmt.Sprintf("%s IN [%s]", entity.DocumentVectorCol.DocumentId, joined)
	}

	var hits []map[string]interface{}
	if _, err = meilisearch.DB().Search(ctx, params, public.IndexNameDocumentChunk, &hits); err != nil {
		return nil, err
	}

	// Copy the hits into a fresh gdb.List.
	rows := make(gdb.List, len(hits))
	for i := range hits {
		rows[i] = hits[i]
	}
	return rows, nil
}

View File

@@ -6,7 +6,7 @@ import (
"rag/model/dto"
"rag/model/entity"
"gitea.com/red-future/common/db/gfdb"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
@@ -44,7 +44,7 @@ func (d *keywordDao) BatchSaveOrUpdate(ctx context.Context, req []*dto.CreateKey
}
func (d *keywordDao) Update(ctx context.Context, req *dto.UpdateKeywordReq) (rows int64, err error) {
model := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameKeyword)
model := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameKeyword).OmitEmpty()
r, err := model.Data(&req).Where(entity.KeywordCol.Id, req.Id).Update()
if err != nil {
return
@@ -53,7 +53,10 @@ func (d *keywordDao) Update(ctx context.Context, req *dto.UpdateKeywordReq) (row
}
func (d *keywordDao) Delete(ctx context.Context, req *dto.DeleteKeywordReq) (rows int64, err error) {
r, err := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameKeyword).Where(entity.KeywordCol.Id, req.Id).Delete()
r, err := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameKeyword).OmitEmpty().
Where(entity.KeywordCol.Id, req.Id).
Where(entity.KeywordCol.DocumentId, req.DocumentId).
Delete()
if err != nil {
return
}
@@ -82,6 +85,9 @@ func (d *keywordDao) List(ctx context.Context, req *dto.ListKeywordReq, fields .
if !g.IsEmpty(req.Keyword) {
model.WhereLike(entity.KeywordCol.Word, "%"+req.Keyword+"%")
}
model.WhereIn(entity.KeywordCol.Word, req.Words)
model.Where(entity.KeywordCol.DatasetId, req.DatasetId)
model.Where(entity.KeywordCol.DocumentId, req.DocumentId)
model.OrderDesc(entity.KeywordCol.Weight)
model.OrderDesc(entity.KeywordCol.CreatedAt)
if req.Page != nil {

88
dao/model.go Normal file
View File

@@ -0,0 +1,88 @@
package dao
import (
"context"
"rag/consts/public"
"rag/model/dto"
"rag/model/entity"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// modelDao groups the data-access methods for the model table.
type modelDao struct{}

// Model is the shared DAO instance.
var Model = &modelDao{}
// Insert converts req into an entity.Model, inserts it, and returns the
// last insert ID reported by the driver.
func (d *modelDao) Insert(ctx context.Context, req *dto.CreateModelReq) (id int64, err error) {
	var row *entity.Model
	if err = gconv.Struct(req, &row); err != nil {
		return 0, err
	}
	result, err := gfdb.DB(ctx, public.DbNameKnowledge).
		Model(ctx, public.TableNameModel).
		Data(&row).
		Insert()
	if err != nil {
		return 0, err
	}
	return result.LastInsertId()
}
// Update writes req onto the model row whose id equals req.Id and returns
// the number of affected rows.
func (d *modelDao) Update(ctx context.Context, req *dto.UpdateModelReq) (rows int64, err error) {
	result, err := gfdb.DB(ctx, public.DbNameKnowledge).
		Model(ctx, public.TableNameModel).
		OmitEmpty().
		Data(&req).
		Where(entity.ModelCol.Id, req.Id).
		Update()
	if err != nil {
		return 0, err
	}
	return result.RowsAffected()
}
// Delete removes the model row whose id equals req.Id and returns the
// number of affected rows.
func (d *modelDao) Delete(ctx context.Context, req *dto.DeleteModelReq) (rows int64, err error) {
	query := gfdb.DB(ctx, public.DbNameKnowledge).
		Model(ctx, public.TableNameModel).
		Where(entity.ModelCol.Id, req.Id)
	deleted, err := query.Delete()
	if err != nil {
		return 0, err
	}
	return deleted.RowsAffected()
}
// Count returns the number of model rows, filtered by req.ModelType
// (the filter is chained after OmitEmpty).
func (d *modelDao) Count(ctx context.Context, req *dto.GetModelReq) (count int, err error) {
	return gfdb.DB(ctx, public.DbNameKnowledge).
		Model(ctx, public.TableNameModel).
		OmitEmpty().
		Where(entity.ModelCol.ModelType, req.ModelType).
		Count()
}
// Get fetches one model row filtered by req.Id and req.ModelType (chained
// after OmitEmpty); fields optionally restricts the selected columns.
func (d *modelDao) Get(ctx context.Context, req *dto.GetModelReq, fields ...string) (res *entity.Model, err error) {
	record, err := gfdb.DB(ctx, public.DbNameKnowledge).
		Model(ctx, public.TableNameModel).
		Fields(fields).
		OmitEmpty().
		Where(entity.ModelCol.Id, req.Id).
		Where(entity.ModelCol.ModelType, req.ModelType).
		One()
	if err != nil {
		return nil, err
	}
	err = record.Struct(&res)
	return res, err
}
// GetNoTenantId fetches all model rows of req.ModelType through a query built
// with NoTenantId(ctx); fields optionally restricts the selected columns.
func (d *modelDao) GetNoTenantId(ctx context.Context, req *dto.GetModelReq, fields ...string) (res []*entity.Model, err error) {
	records, err := gfdb.DB(ctx, public.DbNameKnowledge).
		Model(ctx, public.TableNameModel).
		NoTenantId(ctx).
		Where(entity.ModelCol.ModelType, req.ModelType).
		Fields(fields).
		All()
	if err != nil {
		return nil, err
	}
	err = records.Structs(&res)
	return res, err
}
// List returns the model rows matching req together with the total match
// count reported by AllAndCount.
//
// A non-empty req.ModelName adds a `%name%` LIKE match; req.ModelType is
// applied as a filter after OmitEmpty(). Paging is applied only when req.Page
// is non-nil. The optional fields restrict the selected columns.
func (d *modelDao) List(ctx context.Context, req *dto.ListModelReq, fields ...string) (res []*entity.Model, total int, err error) {
model := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameModel).Fields(fields).OmitEmpty()
// NOTE(review): the results of the builder calls below are discarded, so this
// relies on the model being mutated in place — confirm against the gfdb wrapper.
if !g.IsEmpty(req.ModelName) {
model.WhereLike(entity.ModelCol.ModelName, "%"+req.ModelName+"%")
}
model.Where(entity.ModelCol.ModelType, req.ModelType)
// NOTE(review): the sort column comes from entity.KeywordCol, not
// entity.ModelCol — looks like a copy-paste; confirm ModelCol exposes the
// same column and switch to it.
model.OrderDesc(entity.KeywordCol.CreatedAt)
if req.Page != nil {
model.Page(int(req.Page.PageNum), int(req.Page.PageSize))
}
r, total, err := model.AllAndCount(false)
if err != nil {
return
}
// Map the raw result rows onto the entity slice.
err = r.Structs(&res)
return
}

67
dao/task.go Normal file
View File

@@ -0,0 +1,67 @@
package dao
import (
"context"
"rag/consts/public"
"rag/model/dto"
"rag/model/entity"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/util/gconv"
)
// taskDao groups the data-access methods for the task table.
type taskDao struct{}

// Task is the shared DAO instance.
var Task = &taskDao{}
// Insert converts req into an entity.Task, inserts it, and returns the last
// insert ID reported by the driver.
func (d *taskDao) Insert(ctx context.Context, req *dto.CreateTaskReq) (id int64, err error) {
	var row *entity.Task
	if err = gconv.Struct(req, &row); err != nil {
		return 0, err
	}
	result, err := gfdb.DB(ctx, public.DbNameKnowledge).
		Model(ctx, public.TableNameTask).
		Data(&row).
		Insert()
	if err != nil {
		return 0, err
	}
	return result.LastInsertId()
}
// Update writes req onto the task rows filtered by req.Id and req.TaskId
// (chained after OmitEmpty) and returns the number of affected rows.
func (d *taskDao) Update(ctx context.Context, req *dto.UpdateTaskReq) (rows int64, err error) {
	result, err := gfdb.DB(ctx, public.DbNameKnowledge).
		Model(ctx, public.TableNameTask).
		OmitEmpty().
		Data(&req).
		Where(entity.TaskCol.Id, req.Id).
		Where(entity.TaskCol.TaskId, req.TaskId).
		Update()
	if err != nil {
		return 0, err
	}
	return result.RowsAffected()
}
// Count returns the number of task rows filtered by task ID, task type and
// status (chained after OmitEmpty).
func (d *taskDao) Count(ctx context.Context, req *dto.GetTaskReq) (count int, err error) {
	query := gfdb.DB(ctx, public.DbNameKnowledge).
		Model(ctx, public.TableNameTask).
		OmitEmpty().
		Where(entity.TaskCol.TaskId, req.TaskId).
		Where(entity.TaskCol.TaskType, req.TaskType).
		Where(entity.TaskCol.Status, req.TaskStatus)
	return query.Count()
}
// Get returns the task rows filtered by ID, task ID and task type (chained
// after OmitEmpty), plus the total count reported by AllAndCount.
func (d *taskDao) Get(ctx context.Context, req *dto.GetTaskReq) (res []*entity.Task, total int, err error) {
	records, total, err := gfdb.DB(ctx, public.DbNameKnowledge).
		Model(ctx, public.TableNameTask).
		OmitEmpty().
		Where(entity.TaskCol.Id, req.Id).
		Where(entity.TaskCol.TaskId, req.TaskId).
		Where(entity.TaskCol.TaskType, req.TaskType).
		AllAndCount(false)
	if err != nil {
		// total is passed through as AllAndCount reported it.
		return nil, total, err
	}
	err = records.Structs(&res)
	return res, total, err
}
// DeleteByTaskId removes the task rows whose task ID equals req.TaskId and
// returns the number of affected rows.
func (d *taskDao) DeleteByTaskId(ctx context.Context, req *dto.DeleteTaskByTaskIdReq) (rows int64, err error) {
	query := gfdb.DB(ctx, public.DbNameKnowledge).
		Model(ctx, public.TableNameTask).
		Where(entity.TaskCol.TaskId, req.TaskId)
	deleted, err := query.Delete()
	if err != nil {
		return 0, err
	}
	return deleted.RowsAffected()
}

172
go.mod
View File

@@ -1,89 +1,112 @@
module rag
go 1.26.0
go 1.26.1
require (
gitea.com/red-future/common v0.0.11
github.com/bjang03/gmq v0.0.0-00010101000000-000000000000
github.com/cloudwego/eino v0.8.6
github.com/cloudwego/eino-ext/components/document/loader/url v0.0.0-20260323112355-f061db7e8419
github.com/cloudwego/eino-ext/components/document/parser/docx v0.0.0-20260323112355-f061db7e8419
github.com/cloudwego/eino-ext/components/document/parser/pdf v0.0.0-20260323112355-f061db7e8419
github.com/cloudwego/eino-ext/components/document/parser/xlsx v0.0.0-20260323112355-f061db7e8419
github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive v0.0.0-20260323112355-f061db7e8419
github.com/cloudwego/eino-ext/components/document/transformer/splitter/semantic v0.0.0-20260323112355-f061db7e8419
github.com/cloudwego/eino-ext/components/embedding/ark v0.1.1
github.com/cloudwego/eino-ext/components/embedding/dashscope v0.0.0-20260323112355-f061db7e8419
github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260323112355-f061db7e8419
github.com/cloudwego/eino-ext/components/indexer/es8 v0.0.0-20260331071634-4f359694d2d9
github.com/cloudwego/eino-ext/components/model/ark v0.1.65
github.com/cloudwego/eino-ext/components/retriever/es8 v0.0.0-20260331071634-4f359694d2d9
github.com/elastic/go-elasticsearch/v8 v8.16.0
github.com/go-ego/gse v1.0.2
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.0
github.com/gogf/gf/v2 v2.10.0
github.com/golang/glog v1.2.5
github.com/pgvector/pgvector-go v0.3.0
gitea.redpowerfuture.com/red-future/common v0.0.23
github.com/bjang03/gmq v0.0.1
github.com/cloudwego/eino v0.9.5
github.com/cloudwego/eino-ext/components/document/loader/file v0.0.0-20260610041010-8fe947165ad0
github.com/cloudwego/eino-ext/components/document/loader/url v0.0.0-20260610041010-8fe947165ad0
github.com/cloudwego/eino-ext/components/document/parser/docx v0.0.0-20260610041010-8fe947165ad0
github.com/cloudwego/eino-ext/components/document/parser/pdf v0.0.0-20260610041010-8fe947165ad0
github.com/cloudwego/eino-ext/components/document/parser/xlsx v0.0.0-20260610041010-8fe947165ad0
github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive v0.0.0-20260610041010-8fe947165ad0
github.com/cloudwego/eino-ext/components/document/transformer/splitter/semantic v0.0.0-20260610041010-8fe947165ad0
github.com/cloudwego/eino-ext/components/embedding/ark v0.1.2
github.com/cloudwego/eino-ext/components/embedding/dashscope v0.0.0-20260610041010-8fe947165ad0
github.com/cloudwego/eino-ext/components/embedding/ollama v0.0.0-20260610041010-8fe947165ad0
github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260610041010-8fe947165ad0
github.com/cloudwego/eino-ext/components/embedding/qianfan v0.0.0-20260610041010-8fe947165ad0
github.com/cloudwego/eino-ext/components/embedding/tencentcloud v0.0.0-20260610041010-8fe947165ad0
github.com/cloudwego/eino-ext/components/model/ark v0.1.68
github.com/cloudwego/eino-ext/components/model/arkbot v0.1.2
github.com/cloudwego/eino-ext/components/model/claude v0.1.19
github.com/cloudwego/eino-ext/components/model/deepseek v0.1.6
github.com/cloudwego/eino-ext/components/model/ollama v0.1.9
github.com/cloudwego/eino-ext/components/model/openai v0.1.13
github.com/cloudwego/eino-ext/components/model/qianfan v0.1.4
github.com/cloudwego/eino-ext/components/model/qwen v0.1.9
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.2
github.com/gogf/gf/v2 v2.10.2
github.com/pgvector/pgvector-go v0.4.0
)
//replace gitea.com/red-future/common v0.0.11 => ../common
replace github.com/bjang03/gmq => ../gmq
require (
github.com/BurntSushi/toml v1.6.0 // indirect
cloud.google.com/go/auth v0.7.2 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.3 // indirect
cloud.google.com/go/compute/metadata v0.7.0 // indirect
github.com/BurntSushi/toml v1.5.0 // indirect
github.com/PuerkitoBio/goquery v1.8.1 // indirect
github.com/andybalholm/brotli v1.1.1 // indirect
github.com/andybalholm/cascadia v1.3.1 // indirect
github.com/anthropics/anthropic-sdk-go v1.26.0 // indirect
github.com/armon/go-metrics v0.4.1 // indirect
github.com/aws/aws-sdk-go-v2 v1.33.0 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.3 // indirect
github.com/aws/aws-sdk-go-v2/config v1.29.1 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.17.54 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.24 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.28 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.28 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.9 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.24.11 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.10 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.33.9 // indirect
github.com/aws/smithy-go v1.22.1 // indirect
github.com/aymerick/douceur v0.2.0 // indirect
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/baidubce/bce-qianfan-sdk/go/qianfan v0.0.14 // indirect
github.com/baidubce/bce-sdk-go v0.9.164 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/bwmarrin/snowflake v0.3.0 // indirect
github.com/bytedance/gopkg v0.1.3 // indirect
github.com/bytedance/sonic v1.15.0 // indirect
github.com/bytedance/sonic/loader v0.5.0 // indirect
github.com/cenkalti/backoff/v4 v4.1.2 // indirect
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/clbanning/mxj/v2 v2.7.0 // indirect
github.com/clipperhouse/displaywidth v0.11.0 // indirect
github.com/clipperhouse/uax29/v2 v2.7.0 // indirect
github.com/cloudwego/base64x v0.1.6 // indirect
github.com/cloudwego/eino-ext/components/document/parser/html v0.0.0-20241224063832-9fbcc0e56c28 // indirect
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.14 // indirect
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.17 // indirect
github.com/cohesion-org/deepseek-go v1.3.4 // indirect
github.com/dgraph-io/badger/v4 v4.2.0 // indirect
github.com/dgraph-io/ristretto v0.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dslipak/pdf v0.0.2 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/eino-contrib/docx2md v0.0.1 // indirect
github.com/eino-contrib/jsonschema v1.0.3 // indirect
github.com/elastic/elastic-transport-go/v8 v8.10.0 // indirect
github.com/eino-contrib/ollama v0.1.0 // indirect
github.com/emirpasic/gods/v2 v2.0.0-alpha // indirect
github.com/evanphx/json-patch v0.5.2 // indirect
github.com/fatih/color v1.19.0 // indirect
github.com/fatih/color v1.18.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.13 // indirect
github.com/go-ego/gse v1.0.2 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.30.1 // indirect
github.com/goccy/go-json v0.10.6 // indirect
github.com/gogf/gf/contrib/nosql/redis/v2 v2.9.1 // indirect
github.com/gogf/gf/contrib/registry/consul/v2 v2.9.5 // indirect
github.com/gogf/gf/contrib/trace/otlphttp/v2 v2.9.5 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang-jwt/jwt/v5 v5.3.1 // indirect
github.com/golang/glog v1.2.5 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.4 // indirect
github.com/golang/snappy v1.0.0 // indirect
github.com/google/flatbuffers v1.12.1 // indirect
github.com/google/flatbuffers v24.3.25+incompatible // indirect
github.com/google/s2a-go v0.1.7 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/goph/emperror v0.17.2 // indirect
github.com/gorilla/css v1.0.1 // indirect
github.com/gorilla/websocket v1.5.3 // indirect
github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 // indirect
github.com/grokify/html-strip-tags-go v0.1.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
github.com/hashicorp/consul/api v1.26.1 // indirect
@@ -94,75 +117,102 @@ require (
github.com/hashicorp/go-multierror v1.1.1 // indirect
github.com/hashicorp/go-rootcerts v1.0.2 // indirect
github.com/hashicorp/golang-lru v1.0.2 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/hashicorp/serf v0.10.1 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/joho/godotenv v1.5.1 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.18.4 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/klauspost/cpuid/v2 v2.2.11 // indirect
github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/lib/pq v1.12.1 // indirect
github.com/lib/pq v1.10.9 // indirect
github.com/magiconair/properties v1.8.10 // indirect
github.com/mailru/easyjson v0.9.0 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-runewidth v0.0.21 // indirect
github.com/meguminnnnnnnnn/go-openai v0.1.1 // indirect
github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/meguminnnnnnnnn/go-openai v0.1.2 // indirect
github.com/meilisearch/meilisearch-go v0.36.1 // indirect
github.com/microcosm-cc/bluemonday v1.0.27 // indirect
github.com/mitchellh/go-homedir v1.1.0 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee // indirect
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect
github.com/nats-io/nats.go v1.49.0 // indirect
github.com/nats-io/nkeys v0.4.15 // indirect
github.com/nats-io/nuid v1.0.1 // indirect
github.com/nikolalohinski/gonja v1.5.3 // indirect
github.com/olekukonko/cat v0.0.0-20250911104152-50322a0618f6 // indirect
github.com/olekukonko/errors v1.2.0 // indirect
github.com/olekukonko/ll v0.1.8 // indirect
github.com/olekukonko/tablewriter v1.1.4 // indirect
github.com/olekukonko/errors v1.1.0 // indirect
github.com/olekukonko/ll v0.0.9 // indirect
github.com/olekukonko/tablewriter v1.1.0 // indirect
github.com/ollama/ollama v0.9.6 // indirect
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/r3labs/diff/v2 v2.15.1 // indirect
github.com/rabbitmq/amqp091-go v1.10.0 // indirect
github.com/redis/go-redis/v9 v9.18.0 // indirect
github.com/richardlehane/mscfb v1.0.4 // indirect
github.com/richardlehane/msoleps v1.0.4 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/cast v1.10.0 // indirect
github.com/spf13/pflag v1.0.9 // indirect
github.com/spf13/viper v1.18.2 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.1093 // indirect
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/hunyuan v1.0.1093 // indirect
github.com/tidwall/gjson v1.18.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
github.com/tiger1103/gfast-token v1.0.10 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/vcaesar/cedar v0.30.0 // indirect
github.com/volcengine/volc-sdk-golang v1.0.199 // indirect
github.com/volcengine/volcengine-go-sdk v1.2.9 // indirect
github.com/vmihailenco/msgpack v4.0.4+incompatible // indirect
github.com/volcengine/volc-sdk-golang v1.0.23 // indirect
github.com/volcengine/volcengine-go-sdk v1.2.30 // indirect
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
github.com/xuri/efp v0.0.0-20240408161823-9ad904a10d6d // indirect
github.com/xuri/excelize/v2 v2.9.0 // indirect
github.com/xuri/nfp v0.0.0-20240318013403-ab9948c2c4a7 // indirect
github.com/yargevad/filepathx v1.0.0 // indirect
go.mongodb.org/mongo-driver/v2 v2.4.0 // indirect
go.opencensus.io v0.23.0 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/otel v1.42.0 // indirect
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
go.opentelemetry.io/otel v1.38.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.38.0 // indirect
go.opentelemetry.io/otel/metric v1.42.0 // indirect
go.opentelemetry.io/otel/sdk v1.42.0 // indirect
go.opentelemetry.io/otel/trace v1.42.0 // indirect
go.opentelemetry.io/otel/metric v1.38.0 // indirect
go.opentelemetry.io/otel/sdk v1.38.0 // indirect
go.opentelemetry.io/otel/trace v1.38.0 // indirect
go.opentelemetry.io/proto/otlp v1.7.1 // indirect
go.uber.org/atomic v1.11.0 // indirect
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/arch v0.15.0 // indirect
golang.org/x/crypto v0.49.0 // indirect
golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90 // indirect
golang.org/x/net v0.52.0 // indirect
golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.42.0 // indirect
golang.org/x/text v0.35.0 // indirect
golang.org/x/crypto v0.48.0 // indirect
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect
golang.org/x/net v0.49.0 // indirect
golang.org/x/oauth2 v0.30.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.34.0 // indirect
golang.org/x/time v0.9.0 // indirect
google.golang.org/api v0.189.0 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect
google.golang.org/grpc v1.75.0 // indirect
google.golang.org/protobuf v1.36.11 // indirect
google.golang.org/protobuf v1.36.8 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

899
go.sum

File diff suppressed because it is too large Load Diff

45
main.go
View File

@@ -7,10 +7,12 @@ import (
"rag/consts/public"
"rag/controller"
"rag/service"
"strings"
"syscall"
"gitea.com/red-future/common/http"
"gitea.com/red-future/common/jaeger"
"gitea.redpowerfuture.com/red-future/common/http"
"gitea.redpowerfuture.com/red-future/common/jaeger"
"gitea.redpowerfuture.com/red-future/common/utils"
gmq "github.com/bjang03/gmq/core/gmq"
"github.com/bjang03/gmq/mq"
"github.com/bjang03/gmq/types"
@@ -26,30 +28,31 @@ func main() {
http.RouteRegister([]interface{}{
controller.Dataset,
controller.Document,
controller.DocumentChunk,
controller.DocumentVector,
controller.Model,
controller.Keyword,
controller.Task,
})
gmq.Init("config.yml")
if err := gmq.GetGmq("primary").GmqSubscribe(ctx, &mq.RedisSubMessage{
SubMessage: types.SubMessage{
Topic: public.KnowledgeDocumentVectorStatusTopic,
ConsumerName: public.KnowledgeDocumentVectorStatusConsumer,
AutoAck: public.KnowledgeDocumentVectorStatusAutoAck,
FetchCount: public.KnowledgeDocumentVectorStatusBatchSize,
HandleFunc: service.Document.DocsVectorStatusMsg,
},
}); err != nil {
return
if err := utils.InitGseTool(ctx); err != nil {
g.Log().Error(ctx, "gse 分词工具初始化失败:", err)
}
if err := gmq.GetGmq("primary").GmqSubscribe(ctx, &mq.RedisSubMessage{
redisAddress := g.Cfg().MustGet(ctx, "redis.default.address").String()
redisAddressList := strings.Split(redisAddress, ":")
gmq.GmqRegister(public.GmqMsgPluginsName, &mq.RedisConn{
RedisConfig: mq.RedisConfig{
Addr: redisAddressList[0],
Port: redisAddressList[1],
},
})
if err := gmq.GetGmq(public.GmqMsgPluginsName).GmqSubscribe(ctx, &mq.RedisSubMessage{
SubMessage: types.SubMessage{
Topic: public.KnowledgeDocumentChunkTopic,
ConsumerName: public.KnowledgeDocumentChunkConsumer,
AutoAck: public.KnowledgeDocumentChunkAutoAck,
FetchCount: public.KnowledgeDocumentChunkBatchSize,
HandleFunc: service.DocumentChunk.DocsChunkMsg,
Topic: public.KnowledgeDocumentVectorTopic,
ConsumerName: public.KnowledgeDocumentVectorConsumer,
AutoAck: public.KnowledgeDocumentVectorAutoAck,
FetchCount: public.KnowledgeDocumentVectorCount,
HandleFunc: service.DocumentVector.DocsChunkMsg,
},
}); err != nil {
return

View File

@@ -1,14 +1,14 @@
package dto
import (
"gitea.com/red-future/common/beans"
"gitea.redpowerfuture.com/red-future/common/beans"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gtime"
)
// CreateDatasetReq 创建数据集请求
type CreateDatasetReq struct {
g.Meta `path:"/createDataset" method:"post" tags:"知识库(数据集)管理" summary:"创建知识库(数据集)" dc:"创建知识库(数据集)"`
g.Meta `path:"/create" method:"post" tags:"知识库(数据集)管理" summary:"创建知识库(数据集)" dc:"创建知识库(数据集)"`
Name string `json:"name" v:"required#名称不能为空"`
Description string `json:"description"`
@@ -21,7 +21,7 @@ type CreateDatasetRes struct {
// UpdateDatasetReq 更新数据集请求
type UpdateDatasetReq struct {
g.Meta `path:"/updateDataset" method:"put" tags:"知识库(数据集)管理" summary:"更新知识库(数据集)" dc:"更新知识库(数据集)"`
g.Meta `path:"/update" method:"put" tags:"知识库(数据集)管理" summary:"更新知识库(数据集)" dc:"更新知识库(数据集)"`
Id int64 `json:"id" v:"required#ID不能为空"`
Name string `json:"name"`
@@ -32,23 +32,24 @@ type UpdateDatasetReq struct {
// DeleteDatasetReq 删除数据集请求
type DeleteDatasetReq struct {
g.Meta `path:"/deleteDataset" method:"delete" tags:"知识库(数据集)管理" summary:"删除知识库(数据集)" dc:"删除知识库(数据集)"`
g.Meta `path:"/delete" method:"delete" tags:"知识库(数据集)管理" summary:"删除知识库(数据集)" dc:"删除知识库(数据集)"`
Id int64 `json:"id" v:"required#ID不能为空"`
}
// GetDatasetReq 获取数据集请求
type GetDatasetReq struct {
g.Meta `path:"/getDataset" method:"get" tags:"知识库(数据集)管理" summary:"获取知识库(数据集)详情" dc:"获取知识库(数据集)详情"`
g.Meta `path:"/get" method:"get" tags:"知识库(数据集)管理" summary:"获取知识库(数据集)详情" dc:"获取知识库(数据集)详情"`
Id int64 `json:"id" v:"required#ID不能为空"`
}
// ListDatasetReq 数据集列表请求
type ListDatasetReq struct {
g.Meta `path:"/listDataset" method:"get" tags:"知识库(数据集)管理" summary:"获取知识库(数据集)列表" dc:"分页查询知识库(数据集)列表,支持多条件筛选"`
g.Meta `path:"/list" method:"get" tags:"知识库(数据集)管理" summary:"获取知识库(数据集)列表" dc:"分页查询知识库(数据集)列表,支持多条件筛选"`
Page *beans.Page `json:"page"`
Ids []int64 `json:"ids" dc:"数据集ID列表"`
Keyword string `json:"keyword" dc:"关键词搜索"`
}

View File

@@ -3,14 +3,15 @@ package dto
import (
"rag/consts/document"
"gitea.com/red-future/common/beans"
"gitea.redpowerfuture.com/red-future/common/beans"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gtime"
"github.com/pgvector/pgvector-go"
)
// CreateDocumentReq 创建文件请求
type CreateDocumentReq struct {
g.Meta `path:"/createDocument" method:"post" tags:"文件管理" summary:"创建文件" dc:"创建文件"`
g.Meta `path:"/create" method:"post" tags:"文件管理" summary:"创建文件" dc:"创建文件"`
DatasetId int64 `json:"datasetId" v:"required#数据集ID不能为空"`
Title string `json:"title" v:"required#标题不能为空"`
@@ -26,7 +27,7 @@ type CreateDocumentRes struct {
// UpdateDocumentReq 更新文件请求
type UpdateDocumentReq struct {
g.Meta `path:"/updateDocument" method:"put" tags:"文件管理" summary:"更新文件" dc:"更新文件"`
g.Meta `path:"/update" method:"put" tags:"文件管理" summary:"更新文件" dc:"更新文件"`
Id int64 `json:"id" v:"required#ID不能为空"`
Status document.Status `json:"status"`
@@ -36,25 +37,33 @@ type UpdateDocumentReq struct {
// DeleteDocumentReq 删除文件请求
type DeleteDocumentReq struct {
g.Meta `path:"/deleteDocument" method:"delete" tags:"文件管理" summary:"删除文件" dc:"删除文件"`
g.Meta `path:"/delete" method:"delete" tags:"文件管理" summary:"删除文件" dc:"删除文件"`
Id int64 `json:"id" v:"required#ID不能为空"`
}
// GetDocumentReq 获取文件请求
type GetDocumentReq struct {
g.Meta `path:"/getDocument" method:"get" tags:"文件管理" summary:"获取文件详情" dc:"获取文件详情"`
g.Meta `path:"/get" method:"get" tags:"文件管理" summary:"获取文件详情" dc:"获取文件详情"`
Id int64 `json:"id" v:"required#ID不能为空"`
Id int64 `json:"id" v:"required#ID不能为空"`
DatasetId int64 `json:"datasetId"`
Title string `json:"title"`
}
type GetDocumentRes struct {
*DocumentVO
ImgAddressPrefix string `json:"imgAddressPrefix"`
}
// ListDocumentReq 文件列表请求
type ListDocumentReq struct {
g.Meta `path:"/listDocument" method:"get" tags:"文件管理" summary:"获取文件列表" dc:"分页查询文件列表,支持多条件筛选"`
g.Meta `path:"/list" method:"get" tags:"文件管理" summary:"获取文件列表" dc:"分页查询文件列表,支持多条件筛选"`
Page *beans.Page `json:"page"`
DatasetId int64 `json:"datasetId"`
Keyword string `json:"keyword" dc:"关键词搜索"`
Title string `json:"title" dc:"文件标题"`
Status document.Status `json:"status"`
}
@@ -68,6 +77,8 @@ type DocumentVO struct {
Id int64 `json:"id,string" dc:"id"`
DatasetId int64 `json:"datasetId,string"`
Title string `json:"title" dc:"文件标题"`
Format string `orm:"format" json:"format" dc:"文件格式"`
FilePath string `orm:"file_path" json:"filePath" dc:"文件存储路径"`
Status document.Status `json:"status" dc:"状态1启用/0停用"`
VectorStatus document.VectorStatus `json:"vectorStatus" dc:"向量化状态 状态: 1 待定, 2 处理, 3 完成, 4 失败"`
ChunkCount int64 `json:"chunkCount" dc:"分块数"`
@@ -76,33 +87,36 @@ type DocumentVO struct {
UpdatedAt *gtime.Time `json:"updatedAt" dc:"更新时间"`
}
// ProcessDocumentReq 处理文件请求(向量化)
type ProcessDocumentReq struct {
g.Meta `path:"/getProcess" method:"get" tags:"文件管理" summary:"文件向量化处理" dc:"文件向量化处理"`
// DocumentVectorReq 处理文件请求(向量化)
type DocumentVectorReq struct {
g.Meta `path:"/vectorization" method:"post" tags:"文件管理" summary:"文件向量化处理" dc:"文件向量化处理"`
Id int64 `json:"id" v:"required#ID不能为空"`
DatasetId int64 `json:"datasetId" v:"required#数据集ID不能为空"`
}
// ProcessDocumentRes 处理文件响应
type ProcessDocumentRes struct {
ChunkCount int64 `json:"chunkCount"`
CostTime int64 `json:"costTime"`
type DocumentVectorRPC struct {
Id int64 `json:"id" dc:"id"`
DatasetId int64 `json:"datasetId" dc:"所属数据集ID"`
DocumentId int64 `json:"documentId" dc:"文件ID"`
ContentHash string `json:"contentHash" dc:"内容hash"`
Vector pgvector.Vector `json:"vector" dc:"向量"`
}
type ListDocumentChunkRPC struct {
List []*DocumentChunkRPC `json:"list"`
type VectorSemanticSplitReq struct {
g.Meta `path:"/vectorSemanticSplit" method:"post" tags:"文件管理" summary:"向量化生成" dc:"向量化生成"`
Id int64 `json:"id" v:"required#ID不能为空"`
}
type DocumentChunkRPC struct {
Id int64 `json:"id" dc:"id"`
DatasetId int64 `json:"datasetId" dc:"所属数据集ID"`
ContentHash string `json:"contentHash" dc:"内容hash"`
type SearchRecursiveSplitReq struct {
g.Meta `path:"/searchRecursiveSplit" method:"post" tags:"文件管理" summary:"全文检索生成" dc:"全文检索生成"`
Id int64 `json:"id" v:"required#ID不能为空"`
}
type KnowledgeDocumentMsg struct {
TenantId uint64 `json:"tenantId"`
Creator string `json:"creator"`
Id int64 `json:"id"`
VectorStatus document.VectorStatus `json:"vectorStatus"`
type KeywordExtractReq struct {
g.Meta `path:"/keywordExtract" method:"post" tags:"文件管理" summary:"关键词提取" dc:"关键词提取"`
Id int64 `json:"id" v:"required#ID不能为空"`
}

View File

@@ -3,38 +3,69 @@ package dto
import (
"rag/consts/document"
"gitea.com/red-future/common/beans"
"gitea.redpowerfuture.com/red-future/common/beans"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gtime"
"github.com/pgvector/pgvector-go"
)
// UpdateDocumentChunkReq 更新文件块向量请求
type UpdateDocumentChunkReq struct {
g.Meta `path:"/updateDocumentChunk" method:"put" tags:"文件块向量管理" summary:"更新文件块" dc:"更新文件块"`
// RAGQueryReq RAG查询请求
type RAGQueryReq struct {
g.Meta `path:"/ragQuery" method:"post" tags:"RAG查询" summary:"执行RAG查询" dc:"执行RAG查询"`
Content string `json:"content" v:"required#查询内容不能为空" dc:"用户问题"`
DatasetIds []int64 `json:"datasetIds" dc:"数据集ID"`
DocumentIds []int64 `json:"documentIds" dc:"文档ID"`
History []*Message `json:"history" dc:"历史对话"`
TopK int `json:"topK" d:"5" dc:"检索topK默认5"`
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
// RAGQueryRes RAG查询响应
type RAGQueryRes struct {
Answer string `json:"answer" dc:"生成的答案"`
}
// UpdateDocumentVectorReq 更新文件块向量请求
type UpdateDocumentVectorReq struct {
g.Meta `path:"/update" method:"put" tags:"文件块向量管理" summary:"更新文件块" dc:"更新文件块"`
Id int64 `json:"id" v:"required#ID不能为空"`
Status document.Status `json:"status"`
}
// ListDocumentChunkReq 文件块向量列表请求
type ListDocumentChunkReq struct {
g.Meta `path:"/listDocumentChunk" method:"get" tags:"文件块向量管理" summary:"获取文件块向量列表" dc:"分页查询文件块向量列表,支持多条件筛选"`
// DeleteDocumentVectorReq is the request bound to the "/delete" route of the
// document-chunk-vector API. It carries an Id and a DocumentId; neither field
// has a "required" validation rule.
// NOTE(review): the route is registered with method "put", whereas the other
// delete requests in this package (e.g. DeleteDocumentReq, DeleteKeywordReq)
// use method "delete" — confirm this is intentional.
type DeleteDocumentVectorReq struct {
g.Meta `path:"/delete" method:"put" tags:"文件块向量管理" summary:"删除文件块" dc:"删除文件块"`
Id int64 `json:"id"`
DocumentId int64 `json:"documentId"`
}
// ListDocumentVectorReq is the request for GET /list on the document-chunk-vector
// API: a paginated listing with optional filters (dataset, document(s), content
// hashes, status, vector status, keyword).
// NOTE(review): the field ContentHashs is a slice but its json tag is the
// singular "contentHash" — confirm clients send it under that name.
type ListDocumentVectorReq struct {
g.Meta `path:"/list" method:"get" tags:"文件块向量管理" summary:"获取文件块向量列表" dc:"分页查询文件块向量列表,支持多条件筛选"`
Page *beans.Page `json:"page"`
Keyword string `json:"keyword" dc:"关键词搜索"`
DatasetId int64 `json:"datasetId"`
DocumentId int64 `json:"documentId"`
DocumentIds []int64 `json:"documentIds"`
ContentHashs []string `json:"contentHash"`
Status document.Status `json:"status"`
VectorStatus document.VectorStatus `json:"vectorStatus"`
}
// ListDocumentChunkRes 文件块向量列表响应
type ListDocumentChunkRes struct {
List []*DocumentChunkItem `json:"list"`
Total int `json:"total"`
// ListDocumentVectorRes 文件块向量列表响应
type ListDocumentVectorRes struct {
List []*DocumentVectorVO `json:"list"`
Total int `json:"total"`
}
type DocumentChunkItem struct {
type DocumentVectorVO struct {
Id int64 `json:"id,string" dc:"id"`
Status document.Status `json:"status" dc:"状态"`
VectorStatus document.VectorStatus `json:"vectorStatus" dc:"向量状态"`
@@ -49,7 +80,7 @@ type DocumentChunkItem struct {
UpdatedAt *gtime.Time `json:"updatedAt" dc:"更新时间"`
}
type VectorDocumentChunkMsg struct {
type VectorDocumentVectorMsg struct {
TenantId uint64 `json:"tenantId"`
Creator string `json:"creator"`
DatasetId int64 `json:"datasetId"` // 数据集ID

View File

@@ -1,19 +1,22 @@
package dto
import (
"gitea.com/red-future/common/beans"
"rag/consts/keyword"
"gitea.redpowerfuture.com/red-future/common/beans"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gtime"
)
// CreateKeywordReq 创建关键词请求
type CreateKeywordReq struct {
g.Meta `path:"/createKeyword" method:"post" tags:"关键词管理" summary:"创建关键词" dc:"创建关键词"`
g.Meta `path:"/create" method:"post" tags:"关键词管理" summary:"创建关键词" dc:"创建关键词"`
DatasetId int64 `json:"datasetId" v:"required#数据集ID不能为空"`
DocumentId int64 `json:"documentId" v:"required#文档ID不能为空"`
Word string `json:"word" v:"required#名称不能为空"`
Weight int16 `json:"weight" v:"required#权重不能为空"`
DatasetId int64 `json:"datasetId" v:"required#数据集ID不能为空"`
DocumentId int64 `json:"documentId" v:"required#文档ID不能为空"`
Word string `json:"word" v:"required#名称不能为空"`
Weight int16 `json:"weight" v:"required#权重不能为空"`
KeywordType keyword.KeywordType `json:"keywordType" v:"required#类型不能为空"`
}
// CreateKeywordRes 创建关键词响应
@@ -23,7 +26,7 @@ type CreateKeywordRes struct {
// UpdateKeywordReq 更新关键词请求
type UpdateKeywordReq struct {
g.Meta `path:"/updateKeyword" method:"put" tags:"关键词管理" summary:"更新关键词" dc:"更新关键词"`
g.Meta `path:"/update" method:"put" tags:"关键词管理" summary:"更新关键词" dc:"更新关键词"`
Id int64 `json:"id" v:"required#ID不能为空"`
Word string `json:"word"`
@@ -32,27 +35,30 @@ type UpdateKeywordReq struct {
// DeleteKeywordReq 删除关键词请求
type DeleteKeywordReq struct {
g.Meta `path:"/deleteKeyword" method:"delete" tags:"关键词管理" summary:"删除关键词" dc:"删除关键词"`
g.Meta `path:"/delete" method:"delete" tags:"关键词管理" summary:"删除关键词" dc:"删除关键词"`
Id int64 `json:"id" v:"required#ID不能为空"`
Id int64 `json:"id"`
DocumentId int64 `json:"documentId"`
}
// GetKeywordReq 获取关键词请求
type GetKeywordReq struct {
g.Meta `path:"/getKeyword" method:"get" tags:"关键词管理" summary:"获取关键词详情" dc:"获取关键词详情"`
g.Meta `path:"/get" method:"get" tags:"关键词管理" summary:"获取关键词详情" dc:"获取关键词详情"`
Id int64 `json:"id" v:"required#ID不能为空"`
}
// ListKeywordReq 关键词列表请求
type ListKeywordReq struct {
g.Meta `path:"/listKeyword" method:"get" tags:"关键词管理" summary:"获取关键词列表" dc:"分页查询关键词列表,支持多条件筛选"`
g.Meta `path:"/list" method:"get" tags:"关键词管理" summary:"获取关键词列表" dc:"分页查询关键词列表,支持多条件筛选"`
Page *beans.Page `json:"page"`
DatasetId int64 `json:"datasetId"`
DocumentId int64 `json:"documentId"`
Word string `json:"word"`
Keyword string `json:"keyword" dc:"关键词搜索"`
Page *beans.Page `json:"page"`
DatasetId int64 `json:"datasetId"`
DocumentId int64 `json:"documentId"`
Word string `json:"word"`
Words []string `json:"words"`
Keyword string `json:"keyword" dc:"关键词搜索"`
KeywordType keyword.KeywordType `json:"keywordType"`
}
// ListKeywordRes 关键词列表响应
@@ -62,9 +68,12 @@ type ListKeywordRes struct {
}
type KeywordVO struct {
Id int64 `json:"id,string" dc:"id"`
Word string `json:"word" dc:"关键词名称"`
Weight int16 `json:"weight" dc:"权重"`
CreatedAt *gtime.Time `json:"createdAt" dc:"创建时间"`
UpdatedAt *gtime.Time `json:"updatedAt" dc:"更新时间"`
Id int64 `json:"id,string" dc:"id"`
Word string `json:"word" dc:"关键词名称"`
Weight int16 `json:"weight" dc:"权重"`
KeywordType keyword.KeywordType `json:"keywordType" dc:"类型"`
DatasetId int64 `json:"datasetId,string" dc:"数据集ID"`
DocumentId int64 `json:"documentId,string" dc:"文档ID"`
CreatedAt *gtime.Time `json:"createdAt" dc:"创建时间"`
UpdatedAt *gtime.Time `json:"updatedAt" dc:"更新时间"`
}

114
model/dto/model.go Normal file
View File

@@ -0,0 +1,114 @@
package dto
import (
"rag/consts/model"
"time"
"gitea.redpowerfuture.com/red-future/common/beans"
"github.com/gogf/gf/v2/frame/g"
)
// GetModelAllEnumsReq is the parameterless request for GET /getAllEnums.
// Per its summary tag it fetches the full set of model enums (model types
// plus config types).
type GetModelAllEnumsReq struct {
g.Meta `path:"/getAllEnums" method:"get" tags:"模型配置管理" summary:"获取全量模型枚举(类型+配置)"`
}
// GetModelEnumRes wraps a list of model enum options.
// NOTE(review): presumably the response paired with GetModelAllEnumsReq —
// confirm against the controller.
type GetModelEnumRes struct {
Options []ModelEnumOption `json:"options"`
}
// ModelEnumOption is a top-level enum entry (a model type such as vector/chat)
// together with the config-type entries nested under it.
type ModelEnumOption struct {
Key interface{} `json:"key"`
Value interface{} `json:"value"`
ConfigTypes []ModelKeyValue `json:"configTypes"` // nested entries use the shared ModelKeyValue shape
}
// ModelKeyValue is the common key/value pair shared by model-type and
// config-type enum entries.
type ModelKeyValue struct {
Key interface{} `json:"key"`
Value interface{} `json:"value"`
}
// GetModelConfigFormFieldsReq is the request for GET /getModelFormField, which
// asks for the config form of a given model type / config type combination.
// Neither field carries a validation rule.
type GetModelConfigFormFieldsReq struct {
g.Meta `path:"/getModelFormField" method:"get" tags:"模型配置管理" summary:"获取模型表单" dc:"获取模型表单列表"`
ModelType model.ModelType `json:"modelType"` // model type: vector/chat
ConfigType model.ModelConfigType `json:"configType"` // config type: ark/ollama/openai...
}
// GetModelConfigFormFieldsRes is the response carrying the model config form:
// the model type, the config type, and the form fields as a list of
// free-form maps (the per-field schema is not defined in this struct).
type GetModelConfigFormFieldsRes struct {
ModelType model.ModelType `json:"modelType"`
ConfigType model.ModelConfigType `json:"configType"`
Fields []map[string]interface{} `json:"fields"`
}
// CreateModelReq is the request for POST /create, which creates a model
// configuration. ModelType and ModelName are validated as required; the
// description, config type and config content have no validation rule.
type CreateModelReq struct {
g.Meta `path:"/create" method:"post" tags:"模型配置管理" summary:"创建模型配置" dc:"创建模型配置"`
ModelType model.ModelType `json:"modelType" v:"required#模型类型不能为空"`
ModelName string `json:"modelName" v:"required#模型名称不能为空"`
ModelDesc string `json:"modelDesc"`
ConfigType model.ModelConfigType `json:"configType"`
ConfigContent map[string]interface{} `json:"configContent"`
}
// CreateModelRes is the response to a model creation; Id is serialized as a
// JSON string (`,string` tag option).
type CreateModelRes struct {
Id int64 `json:"id,string"`
}
// UpdateModelReq is the request for PUT /update, which updates a model
// configuration. Only Id is validated as required; the remaining fields
// carry no validation rule.
type UpdateModelReq struct {
g.Meta `path:"/update" method:"put" tags:"模型配置管理" summary:"更新模型配置" dc:"更新模型配置"`
Id int64 `json:"id" v:"required#ID不能为空"`
ModelType model.ModelType `json:"modelType"`
ModelName string `json:"modelName"`
ModelDesc string `json:"modelDesc"`
ConfigType model.ModelConfigType `json:"configType"`
ConfigContent map[string]interface{} `json:"configContent"`
}
// DeleteModelReq is the request for DELETE /delete, which deletes a model
// configuration by its required Id.
type DeleteModelReq struct {
g.Meta `path:"/delete" method:"delete" tags:"模型配置管理" summary:"删除模型配置" dc:"删除模型配置"`
Id int64 `json:"id" v:"required#ID不能为空"`
}
// GetModelReq is the request for GET /get, which fetches a model
// configuration's details. Neither Id nor ModelType is validated as required.
type GetModelReq struct {
g.Meta `path:"/get" method:"get" tags:"模型配置管理" summary:"获取模型配置详情" dc:"获取模型配置详情"`
Id int64 `json:"id"`
ModelType model.ModelType `json:"modelType"`
}
// ListModelReq is the request for GET /list: a paginated query over model
// configurations with optional model-type and model-name filters.
type ListModelReq struct {
g.Meta `path:"/list" method:"get" tags:"模型配置管理" summary:"获取模型配置列表" dc:"分页查询模型配置列表,支持多条件筛选"`
Page *beans.Page `json:"page"`
ModelType model.ModelType `json:"modelType"`
ModelName string `json:"modelName"`
}
// ListModelRes is the response to a model list query: the returned items
// plus a total count.
type ListModelRes struct {
List []*ModelVO `json:"list"`
Total int `json:"total"`
}
// ModelVO is the view object for a model configuration as returned to
// clients. Id is serialized as a JSON string; ConfigContent is a free-form map.
type ModelVO struct {
Id int64 `json:"id,string"`
ModelType model.ModelType `json:"modelType"`
ModelName string `json:"modelName"`
ModelDesc string `json:"modelDesc"`
ConfigType model.ModelConfigType `json:"configType"`
ConfigContent map[string]interface{} `json:"configContent"`
CreateTime time.Time `json:"createTime"`
UpdateTime time.Time `json:"updateTime"`
}

75
model/dto/task.go Normal file
View File

@@ -0,0 +1,75 @@
package dto
import (
"rag/consts/task"
"github.com/gogf/gf/v2/frame/g"
)
// WriteTaskProgressReq is the payload for writing a task's progress: task
// type, status, task ID and a free-text remark. It has no g.Meta, so it is
// not bound to an HTTP route by itself.
type WriteTaskProgressReq struct {
TaskType task.TaskType `json:"taskType" dc:"任务类型"`
Status task.TaskStatus `json:"status" dc:"任务状态"`
TaskId int64 `json:"taskId" dc:"任务ID"`
Remark string `json:"remark" dc:"备注"`
}
// CreateTaskReq is the payload for creating a task record. Its fields are
// the same as WriteTaskProgressReq (task type, status, task ID, remark).
type CreateTaskReq struct {
TaskType task.TaskType `json:"taskType" dc:"任务类型"`
Status task.TaskStatus `json:"status" dc:"任务状态"`
TaskId int64 `json:"taskId" dc:"任务ID"`
Remark string `json:"remark" dc:"备注"`
}
// UpdateTaskReq is the payload for updating a task's status and remark.
// NOTE(review): Id and TaskId both carry the description "任务ID"; presumably
// Id is the record's primary key and TaskId the business task identifier —
// confirm against the DAO.
type UpdateTaskReq struct {
Id int64 `json:"id" dc:"任务ID"`
TaskId int64 `json:"taskId" dc:"任务ID"`
Status task.TaskStatus `json:"status" dc:"任务状态"`
Remark string `json:"remark" dc:"备注"`
}
// DeleteTaskByTaskIdReq is the payload for deleting task records by their
// TaskId, which is validated as required.
type DeleteTaskByTaskIdReq struct {
TaskId int64 `json:"taskId" v:"required#任务id不能为空"`
}
// GetTaskReq is the request for GET /get, which fetches a task's details.
// It can carry an Id, a TaskId, a task type and a task status; none of them
// is validated as required.
type GetTaskReq struct {
g.Meta `path:"/get" method:"get" tags:"任务管理" summary:"获取任务详情" dc:"获取任务详情"`
Id int64 `json:"id" dc:"任务ID"`
TaskId int64 `json:"taskId" dc:"任务ID"`
TaskType task.TaskType `json:"taskType" dc:"任务类型"`
TaskStatus task.TaskStatus `json:"taskStatus" dc:"任务状态"`
}
// ListTaskRes is a task list response: the returned items plus a total count.
type ListTaskRes struct {
List []*TaskVO `json:"list"`
Total int `json:"total"`
}
// TaskVO is the view object for a task as returned to clients. Per the field
// descriptions, StartTime/EndTime are timestamps (nullable pointers), Duration
// is in milliseconds and Progress is a percentage.
// NOTE(review): Id is emitted as a JSON number here, while ModelVO in this
// package uses `json:"id,string"` — confirm the difference is intentional.
type TaskVO struct {
Id int64 `json:"id" dc:"任务ID"`
TaskType task.TaskType `json:"taskType" dc:"任务类型"`
Status task.TaskStatus `json:"status" dc:"任务状态"`
Priority task.TaskPriority `json:"priority" dc:"任务优先级"`
ParentTaskID int64 `json:"parentTaskId" dc:"父任务ID"`
TotalItems int64 `json:"totalItems" dc:"总项数"`
ProcessedItems int64 `json:"processedItems" dc:"已处理项数"`
Progress float64 `json:"progress" dc:"进度百分比"`
StartTime *int64 `json:"startTime" dc:"开始时间戳"`
EndTime *int64 `json:"endTime" dc:"结束时间戳"`
Duration int64 `json:"duration" dc:"耗时(毫秒)"`
SuccessCount int64 `json:"successCount" dc:"成功数"`
FailCount int64 `json:"failCount" dc:"失败数"`
Executor string `json:"executor" dc:"执行器"`
DocumentID int64 `json:"documentId" dc:"文档ID"`
Remark string `json:"remark" dc:"备注"`
Creator string `json:"creator" dc:"创建人"`
CreatedAt int64 `json:"createdAt" dc:"创建时间"`
Updater string `json:"updater" dc:"更新人"`
UpdatedAt int64 `json:"updatedAt" dc:"更新时间"`
}

View File

@@ -1,7 +1,7 @@
package entity
import (
"gitea.com/red-future/common/beans"
"gitea.redpowerfuture.com/red-future/common/beans"
)
type datasetCol struct {

View File

@@ -1,6 +1,6 @@
package entity
import "gitea.com/red-future/common/beans"
import "gitea.redpowerfuture.com/red-future/common/beans"
type datasetIndexCol struct {
beans.SQLBaseCol

View File

@@ -1,7 +1,7 @@
package entity
import (
"gitea.com/red-future/common/beans"
"gitea.redpowerfuture.com/red-future/common/beans"
"rag/consts/document"
)

View File

@@ -3,11 +3,11 @@ package entity
import (
"rag/consts/document"
"gitea.com/red-future/common/beans"
"gitea.redpowerfuture.com/red-future/common/beans"
"github.com/pgvector/pgvector-go"
)
type documentChunkCol struct {
type documentVectorCol struct {
beans.SQLBaseCol
Status string
VectorStatus string
@@ -20,7 +20,7 @@ type documentChunkCol struct {
Metadata string
}
var DocumentChunkCol = documentChunkCol{
var DocumentVectorCol = documentVectorCol{
SQLBaseCol: beans.DefSQLBaseCol,
Status: "status",
VectorStatus: "vector_status",
@@ -33,8 +33,8 @@ var DocumentChunkCol = documentChunkCol{
Metadata: "metadata",
}
// DocumentChunk 文档切分块实体
type DocumentChunk struct {
// DocumentVector 文档切分块实体
type DocumentVector struct {
beans.SQLBaseDO `orm:",inline"`
Status document.Status `orm:"status" json:"status" dc:"状态"`

View File

@@ -1,27 +1,34 @@
package entity
import "gitea.com/red-future/common/beans"
import (
"rag/consts/keyword"
"gitea.redpowerfuture.com/red-future/common/beans"
)
type keywordCol struct {
beans.SQLBaseCol
DatasetId string
DocumentId string
Word string
Weight string
DatasetId string
DocumentId string
Word string
Weight string
KeywordType string
}
var KeywordCol = keywordCol{
SQLBaseCol: beans.DefSQLBaseCol,
DatasetId: "dataset_id",
DocumentId: "document_id",
Word: "word",
Weight: "weight",
SQLBaseCol: beans.DefSQLBaseCol,
DatasetId: "dataset_id",
DocumentId: "document_id",
Word: "word",
Weight: "weight",
KeywordType: "keyword_type",
}
type Keyword struct {
beans.SQLBaseDO `orm:",inline"`
DatasetId int64 `orm:"dataset_id" json:"datasetId" dc:"数据集ID"`
DocumentId int64 `orm:"document_id" json:"documentId" dc:"文件ID"`
Word string `orm:"word" json:"word" dc:"关键词"`
Weight int16 `orm:"weight" json:"weight" dc:"权重"`
DatasetId int64 `orm:"dataset_id" json:"datasetId" dc:"数据集ID"`
DocumentId int64 `orm:"document_id" json:"documentId" dc:"文件ID"`
Word string `orm:"word" json:"word" dc:"关键词"`
Weight int16 `orm:"weight" json:"weight" dc:"权重"`
KeywordType keyword.KeywordType `orm:"keyword_type" json:"keywordType" dc:"类型"`
}

119
model/entity/model.go Normal file
View File

@@ -0,0 +1,119 @@
package entity
import (
"rag/consts/model"
"gitea.redpowerfuture.com/red-future/common/beans"
)
// modelCol holds the database column names for the Model entity, one string
// field per column, on top of the shared base columns from beans.SQLBaseCol.
// The populated instance is ModelCol.
type modelCol struct {
beans.SQLBaseCol
DatasetId string
ModelType string
ModelName string
ModelDesc string
ConfigType string
ConfigContent string
}
// ModelCol maps each Model field to its snake_case column name; the base
// columns come from beans.DefSQLBaseCol.
var ModelCol = modelCol{
SQLBaseCol: beans.DefSQLBaseCol,
DatasetId: "dataset_id",
ModelType: "model_type",
ModelName: "model_name",
ModelDesc: "model_desc",
ConfigType: "config_type",
ConfigContent: "config_content",
}
// Model is the persisted model-configuration entity. It embeds the shared
// base columns inline and stores the provider-specific settings in
// ConfigContent as a free-form map.
type Model struct {
beans.SQLBaseDO `orm:",inline"`
DatasetId int64 `orm:"dataset_id" json:"datasetId" dc:"数据集ID"`
ModelType model.ModelType `orm:"model_type" json:"modelType" dc:"模型类型"` // vector / chat
ModelName string `orm:"model_name" json:"modelName" dc:"模型名称"`
ModelDesc string `orm:"model_desc" json:"modelDesc" dc:"模型描述"`
ConfigType model.ModelConfigType `orm:"config_type" json:"configType" dc:"配置类型"` // ark/ollama, etc.
ConfigContent map[string]interface{} `orm:"config_content" json:"configContent" dc:"配置详情"` // stored as JSON
}
// -------------------------- Shared config structs (common fields factored out) --------------------------
// OllamaConfig is the Ollama provider config, identical for vector and chat
// models: a base URL and a model name.
type OllamaConfig struct {
BaseURL string `json:"base_url"`
Model string `json:"model"`
}
// OpenAIConfig is the shared OpenAI provider config: API key, model name,
// base URL, plus a by_azure flag and an API version.
// NOTE(review): api_version presumably only applies when by_azure is set —
// confirm against the code that consumes this struct.
type OpenAIConfig struct {
APIKey string `json:"api_key"`
Model string `json:"model"`
ByAzure bool `json:"by_azure"`
BaseURL string `json:"base_url"`
APIVersion string `json:"api_version"`
}
// QianfanConfig is the shared Qianfan provider config: an access key /
// secret key pair and a model name.
type QianfanConfig struct {
AccessKey string `json:"access_key"`
SecretKey string `json:"secret_key"`
Model string `json:"model"`
}
// ArkConfig is the shared Ark provider config: an API key and a model name.
type ArkConfig struct {
APIKey string `json:"api_key"`
Model string `json:"model"`
}
// -------------------------- Vector model configs --------------------------
// The following vector-model configs are type aliases of the shared structs.
type VectorModelConfigOllama = OllamaConfig // reused as-is
type VectorModelConfigOpenAI = OpenAIConfig // reused as-is
type VectorModelConfigQianfan = QianfanConfig // reused as-is
// VectorModelConfigArk is the Ark config for vector models: the shared
// ArkConfig fields (embedded) plus an api_type string.
type VectorModelConfigArk struct {
ArkConfig
APIType string `json:"api_type"`
}
// VectorModelConfigTencentCloud is the Tencent Cloud config for vector
// models: secret ID, secret key and region. It has no model field.
type VectorModelConfigTencentCloud struct {
SecretID string `json:"secret_id"`
SecretKey string `json:"secret_key"`
Region string `json:"region"`
}
// VectorModelConfigDashScope is the DashScope config for vector models: an
// API key and a model name.
type VectorModelConfigDashScope struct {
APIKey string `json:"api_key"`
Model string `json:"model"`
}
// -------------------------- Chat model configs --------------------------
// The following chat-model configs are type aliases of the shared structs.
type ChatModelConfigArk = ArkConfig // reused as-is
type ChatModelConfigArkBot = ArkConfig // reused as-is
type ChatModelConfigOllama = OllamaConfig // reused as-is
type ChatModelConfigOpenAI = OpenAIConfig // reused as-is
type ChatModelConfigQianfan = QianfanConfig // reused as-is
// ChatModelConfigClaude is the Claude chat-model configuration.
// NOTE(review): AccessKey/SecretAccessKey/Region presumably apply when
// ByBedrock is true and APIKey/BaseURL otherwise — confirm against the
// code that builds the client.
type ChatModelConfigClaude struct {
ByBedrock bool `json:"by_bedrock"`
AccessKey string `json:"access_key"`
SecretAccessKey string `json:"secret_access_key"`
Region string `json:"region"`
APIKey string `json:"api_key"`
Model string `json:"model"`
BaseURL string `json:"base_url"`
}
// ChatModelConfigDeepSeek is the DeepSeek chat-model configuration.
type ChatModelConfigDeepSeek struct {
APIKey string `json:"api_key"`
Model string `json:"model"`
BaseURL string `json:"base_url"`
}
// ChatModelConfigQwen is the Qwen chat-model configuration.
type ChatModelConfigQwen struct {
APIKey string `json:"api_key"`
Model string `json:"model"`
BaseURL string `json:"base_url"`
}

66
model/entity/task.go Normal file
View File

@@ -0,0 +1,66 @@
package entity
import (
"rag/consts/task"
"gitea.redpowerfuture.com/red-future/common/beans"
)
// taskCol holds the column names of the task table: the base columns from
// beans.SQLBaseCol plus the task-specific ones. Each field stores a column
// name as a string (see TaskCol for the values).
type taskCol struct {
beans.SQLBaseCol
TaskId string
TaskType string
Status string
Executor string
Remark string
// The columns below are currently disabled.
//Priority string
//ParentTaskId string
//TotalItems string
//ProcessedItems string
//Progress string
//StartTime string
//EndTime string
//Duration string
//SuccessCount string
//FailCount string
}
// TaskCol is the column-name table for the task entity; the base columns come
// from beans.DefSQLBaseCol.
var TaskCol = taskCol{
SQLBaseCol: beans.DefSQLBaseCol,
TaskId: "task_id",
TaskType: "task_type",
Status: "status",
Executor: "executor",
Remark: "remark",
// The columns below are currently disabled.
//Priority: "priority",
//ParentTaskId: "parent_task_id",
//TotalItems: "total_items",
//ProcessedItems: "processed_items",
//Progress: "progress",
//StartTime: "start_time",
//EndTime: "end_time",
//Duration: "duration",
//SuccessCount: "success_count",
//FailCount: "fail_count",
}
// Task is the row type of the task record table.
// NOTE(review): callers visible in this change set pass a document ID as
// TaskId — confirm whether TaskId is always the ID of the processed document.
type Task struct {
beans.SQLBaseDO `orm:",inline"`
TaskId int64 `orm:"task_id" json:"taskId" dc:"任务ID"`
TaskType task.TaskType `orm:"task_type" json:"taskType" dc:"任务类型"`
Status task.TaskStatus `orm:"status" json:"status" dc:"任务状态"`
Executor string `orm:"executor" json:"executor" dc:"执行器"`
Remark string `orm:"remark" json:"remark" dc:"备注"`
// The fields below are currently disabled.
//Priority task.TaskPriority `orm:"priority" json:"priority" dc:"任务优先级"`
//ParentTaskId int64 `orm:"parent_task_id" json:"parentTaskId" dc:"父任务ID"`
//TotalItems int64 `orm:"total_items" json:"totalItems" dc:"总项数"`
//ProcessedItems int64 `orm:"processed_items" json:"processedItems" dc:"已处理项数"`
//SuccessCount int64 `orm:"success_count" json:"successCount" dc:"成功数"`
//FailCount int64 `orm:"fail_count" json:"failCount" dc:"失败数"`
//Progress float64 `orm:"progress" json:"progress" dc:"进度百分比"`
//StartTime *gtime.Time `orm:"start_time" json:"startTime" dc:"开始时间戳"`
//EndTime *gtime.Time `orm:"end_time" json:"endTime" dc:"结束时间戳"`
//Duration int64 `orm:"duration" json:"duration" dc:"耗时(毫秒)"`
}

View File

@@ -45,43 +45,3 @@ func (s *datasetService) List(ctx context.Context, req *dto.ListDatasetReq) (res
err = gconv.Struct(list, &res.List)
return
}
//// Search 搜索(示例,实际需要调用向量库)
//func (s *datasetService) Search(ctx context.Context, req *dto.SearchReq) (res *dto.SearchRes, err error) {
// // 1. 获取数据集信息
// kb, err := dao.Dataset.GetByID(ctx, req)
// if err != nil {
// return nil, err
// }
//
// // 2. 获取文件块
// chunks, err := dao.Chunk.FindChunksByKBIDWithLimit(ctx, req.KBID, 0, req.TopK)
// if err != nil {
// return nil, err
// }
//
// // 3. TODO: 使用向量检索(需要集成向量库)
// // 暂时使用简单的关键词匹配
// results := make([]dto.SearchResult, 0)
// for _, chunk := range chunks {
// results = append(results, dto.SearchResult{
// Content: chunk.Content,
// Score: 0.8, // TODO: 计算实际向量相似度
// DocumentID: chunk.DocumentID,
// ChunkIndex: chunk.Index,
// })
// }
//
// g.Log().Infof(ctx, "数据集[%s]搜索完成,查询:%s,结果数:%d", kb.Name, req.Query, len(results))
//
// return &dto.SearchRes{Results: results}, nil
//}
//
//// formatChunks 格式化文件块为上下文
//func (s *datasetService) formatChunks(chunks []*entity.DocumentChunk) string {
// var sb strings.Builder
// for i, chunk := range chunks {
// sb.WriteString(fmt.Sprintf("[%d] %s\n\n", i+1, chunk.Content))
// }
// return sb.String()
//}

View File

@@ -5,30 +5,28 @@ import (
"errors"
"fmt"
"rag/common/eino"
"rag/common/gse"
"rag/consts/document"
"rag/consts/keyword"
"rag/consts/model"
"rag/consts/public"
"rag/consts/task"
"rag/dao"
"rag/model/dto"
"rag/model/entity"
"strings"
"sync"
"time"
"gitea.com/red-future/common/beans"
"gitea.com/red-future/common/db/gfdb"
"gitea.com/red-future/common/full-text-search/meilisearch"
"gitea.com/red-future/common/http"
"gitea.com/red-future/common/utils"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"gitea.redpowerfuture.com/red-future/common/full-text-search/meilisearch"
"gitea.redpowerfuture.com/red-future/common/utils"
gmq "github.com/bjang03/gmq/core/gmq"
"github.com/bjang03/gmq/mq"
"github.com/bjang03/gmq/types"
"github.com/cloudwego/eino/schema"
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/crypto/gmd5"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/database/gredis"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/grpool"
"github.com/gogf/gf/v2/util/gconv"
)
@@ -38,7 +36,35 @@ type documentService struct{}
// Create 创建文件
func (s *documentService) Create(ctx context.Context, req *dto.CreateDocumentReq) (res *dto.CreateDocumentRes, err error) {
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) (err error) {
err = gfdb.DB(ctx, public.DbNameKnowledge).Transaction(ctx, func(ctx context.Context, tx gdb.TX) (err error) {
doc, err := dao.Document.Get(ctx, &dto.GetDocumentReq{
DatasetId: req.DatasetId,
Title: req.Title,
})
if err != nil {
return
}
if !g.IsEmpty(doc) && doc.Id > 0 {
_, err = dao.Keyword.Delete(ctx, &dto.DeleteKeywordReq{
DocumentId: doc.Id,
})
if err != nil {
return err
}
_, err = dao.DocumentVector.Delete(ctx, &dto.DeleteDocumentVectorReq{
DocumentId: doc.Id,
})
if err != nil {
return err
}
_, err = dao.Document.Delete(ctx, &dto.DeleteDocumentReq{
Id: doc.Id,
})
if err != nil {
return err
}
}
var id int64
id, err = dao.Document.Insert(ctx, req)
if err != nil {
@@ -54,7 +80,16 @@ func (s *documentService) Create(ctx context.Context, req *dto.CreateDocumentReq
return
}
res = &dto.CreateDocumentRes{Id: id}
// 写入任务进度进行中 任务类型为文档解析
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: id,
TaskType: task.TaskTypeDocParse,
Status: task.TaskStatusCompleted,
Remark: "文档上传完成",
})
if err != nil {
return
}
return
})
@@ -69,21 +104,42 @@ func (s *documentService) Update(ctx context.Context, req *dto.UpdateDocumentReq
// Delete 删除文件
func (s *documentService) Delete(ctx context.Context, req *dto.DeleteDocumentReq) (err error) {
docs, err := dao.Document.GetByID(ctx, &dto.GetDocumentReq{Id: req.Id})
docs, err := dao.Document.Get(ctx, &dto.GetDocumentReq{Id: req.Id})
if err != nil {
return
}
err = gfdb.DB(ctx).Transaction(ctx, func(ctx context.Context, tx gdb.TX) (err error) {
err = gfdb.DB(ctx, public.DbNameKnowledge).Transaction(ctx, func(ctx context.Context, tx gdb.TX) (err error) {
datasetReq := &dto.UpdateDatasetReq{
Id: docs.DatasetId,
DocumentCount: -1,
DocumentSize: -docs.FileSize,
}
_, err = dao.Dataset.Update(ctx, datasetReq)
if err != nil {
if _, err = dao.Dataset.Update(ctx, datasetReq); err != nil {
return
}
_, err = dao.Document.Delete(ctx, req)
if _, err = dao.Document.Delete(ctx, req); err != nil {
return
}
if _, err = dao.Keyword.Delete(ctx, &dto.DeleteKeywordReq{
DocumentId: docs.Id,
}); err != nil {
return err
}
if _, err = dao.DocumentVector.Delete(ctx, &dto.DeleteDocumentVectorReq{
DocumentId: docs.Id,
}); err != nil {
return err
}
if _, err = dao.Task.DeleteByTaskId(ctx, &dto.DeleteTaskByTaskIdReq{
TaskId: docs.Id,
}); err != nil {
return
}
return
})
@@ -91,9 +147,17 @@ func (s *documentService) Delete(ctx context.Context, req *dto.DeleteDocumentReq
}
// Get 获取文件详情
func (s *documentService) Get(ctx context.Context, req *dto.GetDocumentReq) (res *dto.DocumentVO, err error) {
r, err := dao.Document.GetByID(ctx, req)
err = gconv.Struct(r, &res)
func (s *documentService) Get(ctx context.Context, req *dto.GetDocumentReq) (res *dto.GetDocumentRes, err error) {
r, err := dao.Document.Get(ctx, req)
if err != nil {
return
}
res = &dto.GetDocumentRes{}
err = gconv.Struct(r, &res.DocumentVO)
if err != nil {
return
}
res.ImgAddressPrefix, err = utils.GetFileAddressPrefix(ctx)
return
}
@@ -107,234 +171,516 @@ func (s *documentService) List(ctx context.Context, req *dto.ListDocumentReq) (r
Total: total,
}
err = gconv.Struct(list, &res.List)
//eino.TestIndexer()
//eino.TestRetriever()
return
}
// Process 处理文件(使用eino框架切分和向量化)
func (s *documentService) Process(ctx context.Context, req *dto.ProcessDocumentReq) (res *dto.ProcessDocumentRes, err error) {
startTime := time.Now()
func (s *documentService) VectorSemanticSplit(ctx context.Context, req *dto.VectorSemanticSplitReq) (err error) {
// 1. 查询文件信息
documentReq := dto.GetDocumentReq{Id: req.Id}
doc, err := dao.Document.GetByID(ctx, &documentReq)
doc, err := dao.Document.Get(ctx, &documentReq)
if err != nil {
return nil, err
return err
}
if g.IsEmpty(doc) {
return nil, errors.New("document not found")
return errors.New("document not found")
}
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: req.Id,
TaskType: task.TaskTypeGenerateVector,
Status: task.TaskStatusRunning,
Remark: "向量化执行中",
})
return s.semanticSplitDocument(ctx, doc)
}
// 2. 使用eino框架进行文件切分并发执行
var vectorDocsCount, chunks int64
// 用 gopool 或者简单的错误等待,绝对不用裸 goroutine
var err1, err2, err3 error
var wg sync.WaitGroup
wg.Add(3)
// SearchRecursiveSplit records the full-text-search task of document req.Id as
// running and then runs the recursive split via recursiveSplitDocument.
//
// It returns an error when the document lookup fails, when no document is
// found, or when recursiveSplitDocument fails. A failure to record the
// "running" progress does not abort the split; it is logged instead of being
// silently dropped.
func (s *documentService) SearchRecursiveSplit(ctx context.Context, req *dto.SearchRecursiveSplitReq) (err error) {
	// 1. Look up the document.
	doc, err := dao.Document.Get(ctx, &dto.GetDocumentReq{Id: req.Id})
	if err != nil {
		return err
	}
	if g.IsEmpty(doc) {
		return errors.New("document not found")
	}

	// 2. Mark the task as running. The original assigned this result to err
	// and then immediately overwrote it with the return below, so a failed
	// progress write vanished without a trace.
	if progressErr := Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
		TaskId:   req.Id,
		TaskType: task.TaskTypeFullTextSearch,
		Status:   task.TaskStatusRunning,
		Remark:   "全文检索执行中",
	}); progressErr != nil {
		g.Log().Errorf(ctx, "写入任务进度失败: %v", progressErr)
	}

	// 3. Split the document.
	return s.recursiveSplitDocument(ctx, doc)
}
// 任务1
go func() {
defer wg.Done()
vectorDocsCount, chunks, err1 = s.sqlSplitDocument(ctx, doc)
}()
// 任务2
go func() {
defer wg.Done()
err2 = s.esSplitDocument(ctx, doc)
}()
// 任务3
go func() {
defer wg.Done()
err3 = s.extractDocument(ctx, doc)
}()
// 直接等待,不使用通道,避免泄漏
wg.Wait()
// KeywordExtract records the keyword-extraction task of document req.Id as
// running and then runs the extraction via extractDocument.
//
// It returns an error when the document lookup fails, when no document is
// found, or when extractDocument fails. A failure to record the "running"
// progress does not abort the extraction; it is logged instead of being
// silently dropped.
func (s *documentService) KeywordExtract(ctx context.Context, req *dto.KeywordExtractReq) (err error) {
	// 1. Look up the document.
	doc, err := dao.Document.Get(ctx, &dto.GetDocumentReq{Id: req.Id})
	if err != nil {
		return err
	}
	if g.IsEmpty(doc) {
		return errors.New("document not found")
	}

	// 2. Mark the task as running. The original assigned this result to err
	// and then immediately overwrote it with the return below, so a failed
	// progress write vanished without a trace.
	if progressErr := Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
		TaskId:   req.Id,
		TaskType: task.TaskTypeExtractKeywords,
		Status:   task.TaskStatusRunning,
		Remark:   "提取关键词执行中",
	}); progressErr != nil {
		g.Log().Errorf(ctx, "写入任务进度失败: %v", progressErr)
	}

	// 3. Extract the keywords.
	return s.extractDocument(ctx, doc)
}
// Vector 处理文件(使用eino框架切分和向量化)
func (s *documentService) Vector(ctx context.Context, req *dto.DocumentVectorReq) (err error) {
// 更新文档状态为处理中
updateDocumentReq := new(dto.UpdateDocumentReq)
updateDocumentReq.Id = req.Id
// 统一判断错误
if err1 != nil || err2 != nil || err3 != nil {
// 更新文档状态
updateDocumentReq.VectorStatus = document.VectorStatusFailed.Code()
if _, err = dao.Document.Update(ctx, updateDocumentReq); err != nil {
return nil, err
}
if err1 != nil {
return nil, err1
}
if err2 != nil {
return nil, err2
}
return nil, err3
}
// 4. 更新文件状态为处理中和切分数量
if vectorDocsCount > 0 {
updateDocumentReq.VectorStatus = document.VectorStatusProcessing.Code()
} else {
updateDocumentReq.VectorStatus = document.VectorStatusCompleted.Code()
}
updateDocumentReq.ChunkCount = chunks
updateDocumentReq.VectorStatus = document.VectorStatusProcessing.Code()
if _, err = dao.Document.Update(ctx, updateDocumentReq); err != nil {
// 写入任务进度失败 任务类型为文档解析
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: req.Id,
TaskType: task.TaskTypeDocParse,
Status: task.TaskStatusFailed,
Remark: "更新文档状态失败: " + err.Error(),
})
return
}
costTime := time.Since(startTime).Milliseconds()
user, err := utils.GetUserInfo(ctx)
if err != nil {
return err
}
// 使用带超时的background context避免HTTP请求完成后context被取消
taskCtx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
taskCtx = context.WithValue(taskCtx, "user", user)
// 任务1: 语义 切分文档
grpool.Add(taskCtx, func(ctx context.Context) {
g.TryCatch(ctx, func(ctx context.Context) {
if innerErr := s.VectorSemanticSplit(ctx, &dto.VectorSemanticSplitReq{Id: req.Id}); innerErr != nil {
cancel()
}
}, func(ctx context.Context, err error) {
cancel()
})
})
return &dto.ProcessDocumentRes{
ChunkCount: chunks,
CostTime: costTime,
}, nil
// 任务2: 递归 切分文档
grpool.Add(taskCtx, func(ctx context.Context) {
g.TryCatch(ctx, func(ctx context.Context) {
if innerErr := s.SearchRecursiveSplit(ctx, &dto.SearchRecursiveSplitReq{Id: req.Id}); innerErr != nil {
cancel()
}
}, func(ctx context.Context, err error) {
cancel()
})
})
// 任务3: 提取文档
grpool.Add(taskCtx, func(ctx context.Context) {
g.TryCatch(ctx, func(ctx context.Context) {
if innerErr := s.KeywordExtract(ctx, &dto.KeywordExtractReq{Id: req.Id}); innerErr != nil {
cancel()
}
}, func(ctx context.Context, err error) {
cancel()
})
})
return nil
}
// extractDocument 关键词提取(支持取消)
func (s *documentService) extractDocument(ctx context.Context, doc *entity.Document) (err error) {
// ========== 取消检查 1方法入口 ==========
if ctx.Err() != nil {
// 写入任务进度失败 任务类型为关键字存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeExtractKeywords,
Status: task.TaskStatusFailed,
Remark: "ctx取消: " + ctx.Err().Error(),
})
return ctx.Err()
}
// 1. 加载文件
docs, err := s.loadDocument(ctx, doc)
if err != nil {
// 写入任务进度失败 任务类型为关键字存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeExtractKeywords,
Status: task.TaskStatusFailed,
Remark: "加载文件失败: " + err.Error(),
})
return
}
var words []gse.Keyword
var words []utils.Keyword
if len(docs[0].Content) < 500 {
words = gse.GseTool.Extract(docs[0].Content, 4)
words = utils.GseTool.Extract(docs[0].Content, 4)
} else if len(docs[0].Content) < 2000 {
words = gse.GseTool.Extract(docs[0].Content, 8)
words = utils.GseTool.Extract(docs[0].Content, 8)
} else if len(docs[0].Content) < 5000 {
words = gse.GseTool.Extract(docs[0].Content, 13)
words = utils.GseTool.Extract(docs[0].Content, 13)
} else {
var docsSplit []*schema.Document
docsSplit, err = eino.RecursiveSplitDocument(ctx, docs)
if err != nil {
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeExtractKeywords,
Status: task.TaskStatusFailed,
Remark: "递归分割文档失败: " + err.Error(),
})
return
}
// ========== 取消检查 2循环内部 ==========
for _, t := range docsSplit {
words = append(words, gse.GseTool.Extract(t.Content, 6)...)
if ctx.Err() != nil {
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeExtractKeywords,
Status: task.TaskStatusFailed,
Remark: "ctx取消: " + ctx.Err().Error(),
})
return ctx.Err()
}
words = append(words, utils.GseTool.Extract(t.Content, 6)...)
}
}
// ========== 取消检查 3批量操作前 ==========
if ctx.Err() != nil {
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeExtractKeywords,
Status: task.TaskStatusFailed,
Remark: "ctx取消: " + ctx.Err().Error(),
})
return ctx.Err()
}
var keywordReqs = make([]*dto.CreateKeywordReq, 0)
for _, word := range words {
keywordReqs = append(keywordReqs, &dto.CreateKeywordReq{
DatasetId: doc.DatasetId,
DocumentId: doc.Id,
Word: word.Word,
Weight: gconv.Int16(word.Score),
DatasetId: doc.DatasetId,
DocumentId: doc.Id,
Word: word.Word,
Weight: gconv.Int16(word.Score),
KeywordType: keyword.KeywordTypeInitial.Code(),
})
}
if len(keywordReqs) > 0 {
_, err = dao.Keyword.BatchSaveOrUpdate(ctx, keywordReqs)
if err != nil {
// 写入任务进度失败 任务类型为关键字存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeExtractKeywords,
Status: task.TaskStatusFailed,
Remark: "关键字存储失败: " + err.Error(),
})
return
}
// 写入任务进度已完成 任务类型为关键字存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeExtractKeywords,
Status: task.TaskStatusCompleted,
Remark: "关键字提取完成",
})
} else {
// 写入任务进度已完成 任务类型为关键字存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeExtractKeywords,
Status: task.TaskStatusCompleted,
Remark: "没有提取到关键词,关键字提取完成",
})
}
return
}
func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Document) (vectorDocsCount, docsSplitCount int64, err error) {
// semanticSplitDocument 语义切分
func (s *documentService) semanticSplitDocument(ctx context.Context, doc *entity.Document) (err error) {
// ========== 取消检查 1方法入口 ==========
if ctx.Err() != nil {
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeGenerateVector,
Status: task.TaskStatusFailed,
Remark: "ctx取消: " + ctx.Err().Error(),
})
return ctx.Err()
}
// 1. 加载文件
docs, err := s.loadDocument(ctx, doc)
if err != nil {
// 写入任务进度失败 任务类型为sql存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeGenerateVector,
Status: task.TaskStatusFailed,
Remark: "加载文件失败: " + err.Error(),
})
return
}
// 2. 语义切分文件
docsSplit, err := eino.SemanticSplitDocument(ctx, docs)
docsSplit, err := eino.SemanticSplitDocument(ctx, docs, model.ModelConfigTypeVectorDashScope.Code()) //TODO 后续替换成本地模型
if err != nil {
// 写入任务进度失败 任务类型为sql存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeGenerateVector,
Status: task.TaskStatusFailed,
Remark: "文档切分失败: " + err.Error(),
})
return
}
docsSplitCount = gconv.Int64(len(docsSplit))
// 2. 获取历史数据
err = s.getHistoryData(ctx, doc, public.KnowledgeLockSqlKey, public.KnowledgeContentHashSqlKey)
if err != nil {
// 写入任务进度失败 任务类型为sql存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeGenerateVector,
Status: task.TaskStatusFailed,
Remark: "获取历史数据失败: " + err.Error(),
})
return
}
// 3. 组装向量文档
var docsChunk = make([]*schema.Document, 0)
for i, t := range docsSplit {
// ========== 取消检查 2循环内部 ==========
if ctx.Err() != nil {
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeGenerateVector,
Status: task.TaskStatusFailed,
Remark: "ctx取消: " + ctx.Err().Error(),
})
return ctx.Err()
}
contentHash := gmd5.MustEncryptString(t.Content)
// 检查是否重复
var success bool
success, err = s.checkRepeat(ctx, public.KnowledgeContentHashSqlKey, contentHash)
var isNew, needCopy bool
isNew, needCopy, err = s.checkRepeatWithDocId(ctx, public.KnowledgeContentHashSqlKey, contentHash, doc.Id)
if err != nil {
// 写入任务进度失败 任务类型为sql存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeGenerateVector,
Status: task.TaskStatusFailed,
Remark: "检查重复数据失败: " + err.Error(),
})
return
}
if !success {
if !isNew && !needCopy {
continue
}
var metaData = make(map[string]any)
metaData[entity.DocumentCol.TenantId] = doc.TenantId
metaData[entity.DocumentCol.Creator] = doc.Creator
metaData[entity.DocumentCol.DatasetId] = doc.DatasetId
metaData[entity.DocumentChunkCol.DocumentId] = doc.Id
metaData[entity.DocumentChunkCol.ContentHash] = contentHash
metaData[entity.DocumentChunkCol.ChunkIndex] = gconv.Int64(i)
metaData[entity.DocumentVectorCol.DocumentId] = doc.Id
metaData[entity.DocumentVectorCol.ContentHash] = contentHash
metaData[entity.DocumentVectorCol.ChunkIndex] = gconv.Int64(i + 1)
if isNew {
metaData["isNew"] = true
}
if needCopy {
metaData["isNew"] = false
}
t.MetaData = metaData
docsChunk = append(docsChunk, t)
}
// ========== 取消检查 3批量发送前 ==========
if ctx.Err() != nil {
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeGenerateVector,
Status: task.TaskStatusFailed,
Remark: "ctx取消: " + ctx.Err().Error(),
})
return ctx.Err()
}
// 4. 发送消息到队列
if len(docsChunk) > 0 {
err = gmq.GetGmq("primary").GmqPublish(ctx, &mq.RedisPubMessage{
err = gmq.GetGmq(public.GmqMsgPluginsName).GmqPublish(ctx, &mq.RedisPubMessage{
PubMessage: types.PubMessage{
Topic: public.KnowledgeDocumentChunkTopic,
Topic: public.KnowledgeDocumentVectorTopic,
Data: docsChunk,
},
})
if err != nil {
// 写入任务进度失败 任务类型为sql存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeGenerateVector,
Status: task.TaskStatusFailed,
Remark: "发送消息到队列失败: " + err.Error(),
})
return
}
// 写入任务进度进行中 任务类型为sql存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeGenerateVector,
Status: task.TaskStatusRunning,
Remark: "向量生成任务已提交到队列",
})
} else {
// 写入任务进度已完成 任务类型为sql存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeGenerateVector,
Status: task.TaskStatusCompleted,
Remark: "无需生成向量,任务完成",
})
}
vectorDocsCount = gconv.Int64(len(docsChunk))
return
}
func (s *documentService) esSplitDocument(ctx context.Context, doc *entity.Document) (err error) {
// recursiveSplitDocument 递归切分
func (s *documentService) recursiveSplitDocument(ctx context.Context, doc *entity.Document) (err error) {
// ========== 取消检查 1方法入口 ==========
if ctx.Err() != nil {
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeFullTextSearch,
Status: task.TaskStatusFailed,
Remark: "ctx取消: " + ctx.Err().Error(),
})
return ctx.Err()
}
// 1. 加载文件
docs, err := s.loadDocument(ctx, doc)
if err != nil {
// 写入任务进度失败 任务类型为es存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeFullTextSearch,
Status: task.TaskStatusFailed,
Remark: "加载文件失败: " + err.Error(),
})
return
}
// 2. 递归切分文件
docsSplit, err := eino.RecursiveSplitDocument(ctx, docs)
if err != nil {
// 写入任务进度失败 任务类型为es存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeFullTextSearch,
Status: task.TaskStatusFailed,
Remark: "文档切分失败: " + err.Error(),
})
return
}
// 2. 获取历史数据
err = s.getHistoryData(ctx, doc, public.KnowledgeLockEsKey, public.KnowledgeContentHashEsKey)
if err != nil {
// 写入任务进度失败 任务类型为es存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeFullTextSearch,
Status: task.TaskStatusFailed,
Remark: "获取历史数据失败: " + err.Error(),
})
return
}
// 3. 组装向量文档并同时构建meilisearch文档
var meiliDocs = make([]interface{}, 0)
for i, t := range docsSplit {
// ========== 取消检查 2循环内部 ==========
if ctx.Err() != nil {
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeFullTextSearch,
Status: task.TaskStatusFailed,
Remark: "ctx取消: " + ctx.Err().Error(),
})
return ctx.Err()
}
contentHash := gmd5.MustEncryptString(t.Content)
// 检查是否重复
var success bool
success, err = s.checkRepeat(ctx, public.KnowledgeContentHashEsKey, contentHash)
var isNew, needCopy bool
isNew, needCopy, err = s.checkRepeatWithDocId(ctx, public.KnowledgeContentHashEsKey, contentHash, doc.Id)
if err != nil {
// 写入任务进度失败 任务类型为es存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeFullTextSearch,
Status: task.TaskStatusFailed,
Remark: "检查重复数据失败: " + err.Error(),
})
return
}
if !success {
if !isNew && !needCopy {
continue
}
// 构建Meilisearch文档
meiliDocs = append(meiliDocs, map[string]interface{}{
entity.DocumentChunkCol.Id: contentHash,
entity.DocumentChunkCol.DatasetId: doc.DatasetId,
entity.DocumentChunkCol.DocumentId: doc.Id,
entity.DocumentChunkCol.Content: t.Content,
entity.DocumentChunkCol.ContentHash: contentHash,
entity.DocumentChunkCol.ChunkIndex: i,
entity.DocumentVectorCol.Id: contentHash,
entity.DocumentVectorCol.DatasetId: doc.DatasetId,
entity.DocumentVectorCol.DocumentId: doc.Id,
entity.DocumentVectorCol.Content: t.Content,
entity.DocumentVectorCol.ContentHash: contentHash,
entity.DocumentVectorCol.ChunkIndex: i + 1,
})
}
// ========== 取消检查 3批量写入前 ==========
if ctx.Err() != nil {
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeFullTextSearch,
Status: task.TaskStatusFailed,
Remark: "ctx取消: " + ctx.Err().Error(),
})
return ctx.Err()
}
// 4. 写入到meilisearch数据库中
if len(meiliDocs) > 0 {
if _, err = meilisearch.DB().InsertMany(ctx, meiliDocs, public.IndexNameDocumentChunk); err != nil {
g.Log().Errorf(ctx, "写入meilisearch失败: %v", err)
// 写入任务进度失败 任务类型为meilisearch存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeFullTextSearch,
Status: task.TaskStatusFailed,
Remark: "写入meilisearch失败: " + err.Error(),
})
return
}
// 写入任务进度已完成 任务类型为meilisearch存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeFullTextSearch,
Status: task.TaskStatusCompleted,
Remark: "全文检索数据写入完成",
})
} else {
// 写入任务进度已完成 任务类型为meilisearch存储
err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
TaskId: doc.Id,
TaskType: task.TaskTypeFullTextSearch,
Status: task.TaskStatusCompleted,
Remark: "无需生成全文检索数据,任务完成",
})
}
return
}
@@ -368,7 +714,8 @@ func (s *documentService) getHistoryData(ctx context.Context, doc *entity.Docume
}
// 3. Redis 无数据:根据 contentKey 类型选择查询方式
var dictData = make([]*dto.DocumentChunkRPC, 0)
var dictData = make([]*dto.DocumentVectorRPC, 0)
if public.KnowledgeContentHashSqlKey == contentKey {
// SQL 方式:调用 HTTP 接口查询
dictData, err = s.getHistoryDataFromHttp(ctx, doc)
@@ -380,20 +727,16 @@ func (s *documentService) getHistoryData(ctx context.Context, doc *entity.Docume
return err
}
// 4. 把查询到的数据写入 Redis600s过期
for _, item := range dictData {
// 去除可能的 JSON 引号
contentHash := strings.Trim(item.ContentHash, `"`)
key := fmt.Sprintf(contentKey, contentHash)
_, err = g.Redis().Set(ctx, key, true, gredis.SetOption{
TTLOption: gredis.TTLOption{
EX: gconv.PtrInt64(600),
},
NX: true,
})
// SAdd把文档ID加入集合自动去重可存多个
_, err = g.Redis().SAdd(ctx, key, item.DocumentId)
if err != nil {
return err
}
// 设置过期时间
_, _ = g.Redis().Expire(ctx, key, 600)
}
return nil
@@ -405,29 +748,20 @@ func (s *documentService) getHistoryData(ctx context.Context, doc *entity.Docume
}
// getHistoryDataFromHttp 通过 HTTP 接口查询历史数据
func (s *documentService) getHistoryDataFromHttp(ctx context.Context, doc *entity.Document) (dictData []*dto.DocumentChunkRPC, err error) {
headers := make(map[string]string)
if r := g.RequestFromCtx(ctx); r != nil {
for k, v := range r.Request.Header {
if len(v) > 0 {
headers[k] = v[0]
}
}
}
func (s *documentService) getHistoryDataFromHttp(ctx context.Context, doc *entity.Document) (dictData []*dto.DocumentVectorRPC, err error) {
// 调用接口获取数据
d := &dto.ListDocumentChunkRPC{}
if err = http.Get(ctx, "rag-vector/document/chunk/listDocumentChunk", headers, &d,
"datasetId", gconv.String(doc.DatasetId),
"status", 1); err != nil {
res, _, err := dao.DocumentVector.List(ctx, &dto.ListDocumentVectorReq{
DatasetId: doc.DatasetId,
})
if err != nil {
return
}
dictData = d.List
err = gconv.Struct(res, &dictData)
return
}
// getHistoryDataFromMeilisearch 通过 meilisearch 查询历史数据
func (s *documentService) getHistoryDataFromMeilisearch(ctx context.Context, doc *entity.Document) (dictData []*dto.DocumentChunkRPC, err error) {
func (s *documentService) getHistoryDataFromMeilisearch(ctx context.Context, doc *entity.Document) (dictData []*dto.DocumentVectorRPC, err error) {
// 构建 meilisearch 查询参数
searchParams := &meilisearch.SearchParams{
Filter: fmt.Sprintf("datasetId = %d", doc.DatasetId),
@@ -442,9 +776,9 @@ func (s *documentService) getHistoryDataFromMeilisearch(ctx context.Context, doc
}
// 转换查询结果
dictData = make([]*dto.DocumentChunkRPC, 0)
dictData = make([]*dto.DocumentVectorRPC, 0)
for _, hit := range hits {
item := &dto.DocumentChunkRPC{}
item := &dto.DocumentVectorRPC{}
if err = gconv.Struct(hit, item); err != nil {
return
}
@@ -453,34 +787,39 @@ func (s *documentService) getHistoryDataFromMeilisearch(ctx context.Context, doc
return
}
// checkRepeat 检查是否重复
func (s *documentService) checkRepeat(ctx context.Context, contentKey, contentHash string) (success bool, err error) {
var val *gvar.Var
if val, err = g.Redis().Set(ctx, fmt.Sprintf(contentKey, contentHash), true, gredis.SetOption{
TTLOption: gredis.TTLOption{
EX: gconv.PtrInt64(600),
},
NX: true,
}); err != nil {
return
}
success = val.Bool()
return
}
// checkRepeatWithDocId 正确版:检查当前文档是否已存在该分片
// 返回isNew(是否需要生成向量)、isCrossDoc(是否跨文档需拷贝)、err
func (s *documentService) checkRepeatWithDocId(ctx context.Context, contentKey string, contentHash string, currentDocId int64) (isNew bool, needCopy bool, err error) {
key := fmt.Sprintf(contentKey, contentHash)
func (s *documentService) DocsVectorStatusMsg(ctx context.Context, msg any) (err error) {
var req = new(dto.KnowledgeDocumentMsg)
if err = gconv.Struct(msg, &req); err != nil {
g.Log().Error(ctx, "DocsVectorStatusMsg err:", err)
return
// 1. 检查当前文档ID是否在集合中
exists, err := g.Redis().SIsMember(ctx, key, currentDocId)
if err != nil {
return false, false, err
}
ctx = context.WithValue(ctx, "user", &beans.User{
TenantId: req.TenantId,
UserName: req.Creator,
})
_, err = dao.Document.Update(ctx, &dto.UpdateDocumentReq{
Id: req.Id,
VectorStatus: req.VectorStatus,
})
return
// 情况1当前文档已存在 → 完全跳过,不生成、不拷贝
if !g.IsEmpty(exists) {
return false, false, nil
}
// 2. 检查 key 是否存在(是否有任何文档拥有该分片)
keyExists, err := g.Redis().Exists(ctx, key)
if err != nil {
return false, false, err
}
// 情况2key 不存在 = 全新数据 → 需要生成向量
if g.IsEmpty(keyExists) {
// 把当前文档ID加入集合
_, err = g.Redis().SAdd(ctx, key, currentDocId)
_, _ = g.Redis().Expire(ctx, key, 600)
return true, false, err
}
// 情况3key 存在,但当前文档不在集合中 = 跨文档重复 → 不生成,需拷贝
// 把当前文档ID加入集合记录归属关系
_, err = g.Redis().SAdd(ctx, key, currentDocId)
_, _ = g.Redis().Expire(ctx, key, 600)
return false, true, err
}

View File

@@ -1,91 +0,0 @@
package service
import (
"context"
"rag/common/eino"
"rag/consts/document"
"rag/consts/public"
"rag/dao"
"rag/model/dto"
"rag/model/entity"
gmq "github.com/bjang03/gmq/core/gmq"
"github.com/bjang03/gmq/mq"
"github.com/bjang03/gmq/types"
"github.com/cloudwego/eino/components/indexer"
"github.com/cloudwego/eino/schema"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// DocumentChunk is the package-level documentChunkService instance.
var DocumentChunk = new(documentChunkService)
// documentChunkService groups the document-chunk service methods; it has no fields.
type documentChunkService struct{}
const (
// DatasetIndexStatusReady is the "ready" index status value.
// NOTE(review): not referenced anywhere in the visible code — confirm it is still needed.
DatasetIndexStatusReady = "ready"
)
// Update forwards the request to dao.DocumentChunk.Update and returns only
// its error; the first return value of the DAO call is discarded.
func (s *documentChunkService) Update(ctx context.Context, req *dto.UpdateDocumentChunkReq) (err error) {
	_, updateErr := dao.DocumentChunk.Update(ctx, req)
	return updateErr
}
// List queries document chunks through dao.DocumentChunk.List and returns
// them together with the total count.
//
// When the DAO call fails the result is nil. When converting the rows into
// the response list fails, the partially filled response is still returned
// alongside the conversion error.
func (s *documentChunkService) List(ctx context.Context, req *dto.ListDocumentChunkReq) (res *dto.ListDocumentChunkRes, err error) {
	rows, count, listErr := dao.DocumentChunk.List(ctx, req)
	if listErr != nil {
		return nil, listErr
	}
	out := &dto.ListDocumentChunkRes{Total: count}
	convErr := gconv.Struct(rows, &out.List)
	return out, convErr
}
// DocsChunkMsg handles a document-chunk queue message: it decodes the "data"
// entry of msg into documents, stores them through the PGVector indexer with
// the DashScope embedder, and then publishes a "vector completed" status
// message for the document the first chunk belongs to.
//
// An empty payload, or a store call that reports zero rows without an error,
// is logged and returns a nil error.
func (s *documentChunkService) DocsChunkMsg(ctx context.Context, msg any) (err error) {
	payload := gconv.Map(msg)
	chunks := make([]*schema.Document, 0)
	if err = gconv.Structs(payload["data"], &chunks); err != nil {
		g.Log().Error(ctx, "DocsChunkMsg err:", err)
		return err
	}
	if len(chunks) == 0 {
		g.Log().Error(ctx, "DocsChunkMsg err:", "msg is empty")
		return nil
	}

	pgIndexer := eino.NewPGVectorIndexer(&eino.PGVectorIndexerOptions{
		BatchSize: 10,
	})
	stored, err := pgIndexer.Store(ctx, chunks, indexer.WithEmbedding(eino.EmbedderDashscope))
	if err != nil || stored == 0 {
		g.Log().Error(ctx, "DocsChunkMsg rows: , err:", stored, err)
		return err
	}

	// Tenant, creator and document ID are read from the first chunk's metadata.
	meta := chunks[0].MetaData
	return s.publishKnowledgeDocumentMsg(
		ctx,
		gconv.Uint64(meta[entity.DocumentChunkCol.TenantId]),
		gconv.String(meta[entity.DocumentChunkCol.Creator]),
		gconv.Int64(meta[entity.DocumentChunkCol.DocumentId]),
		document.VectorStatusCompleted.Code(),
	)
}
// publishKnowledgeDocumentMsg publishes a dto.KnowledgeDocumentMsg carrying
// the given tenant, creator, document ID and vector status to the
// KnowledgeDocumentVectorStatusTopic topic on the "primary" gmq instance.
func (s *documentChunkService) publishKnowledgeDocumentMsg(ctx context.Context, tenantId uint64, creator string, documentId int64, vectorStatus document.VectorStatus) (err error) {
	statusMsg := &mq.RedisPubMessage{
		PubMessage: types.PubMessage{
			Topic: public.KnowledgeDocumentVectorStatusTopic,
			Data: dto.KnowledgeDocumentMsg{
				TenantId:     tenantId,
				Creator:      creator,
				Id:           documentId,
				VectorStatus: vectorStatus,
			},
		},
	}
	return gmq.GetGmq("primary").GmqPublish(ctx, statusMsg)
}

221
service/document_vector.go Normal file
View File

@@ -0,0 +1,221 @@
package service
import (
"context"
"fmt"
"rag/common/eino"
"rag/consts/model"
"rag/consts/task"
"rag/dao"
"rag/model/dto"
"rag/model/entity"
"gitea.redpowerfuture.com/red-future/common/beans"
"github.com/cloudwego/eino/components/retriever"
"github.com/cloudwego/eino/schema"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
"github.com/pgvector/pgvector-go"
)
// DocumentVector is the package-level instance of documentVectorService.
var DocumentVector = new(documentVectorService)
// documentVectorService carries the document-vector service methods; it has no fields.
type documentVectorService struct{}
// Query answers a RAG request.
//
// It loads the chat model configuration, retrieves documents for req.Content
// from the PGVector retriever (restricted by req.DatasetIds and
// req.DocumentIds, limited to req.TopK), converts req.History into chat
// messages and asks the chat model for a reply. The reply content is returned
// as the answer.
//
// Every failure is logged and returned wrapped with the step that failed.
func (s *documentVectorService) Query(ctx context.Context, req *dto.RAGQueryReq) (*dto.RAGQueryRes, error) {
	// 1. Load the chat model configuration.
	modelInfo, err := dao.Model.Get(ctx, &dto.GetModelReq{
		ModelType: model.ModelTypeChat.Code(),
	})
	if err != nil {
		g.Log().Errorf(ctx, "获取模型失败: %v", err)
		return nil, fmt.Errorf("获取模型失败: %w", err)
	}
	if modelInfo == nil {
		g.Log().Errorf(ctx, "模型不存在: %v", model.ModelTypeChat.Code())
		// err is nil on this path, so it must not be wrapped with %w.
		return nil, fmt.Errorf("模型不存在")
	}

	// 2. Build the vector retriever.
	// TODO: switch to a local model later.
	r, err := eino.NewPGVectorRetriever(ctx, &eino.PGVectorRetrieverConfig{
		DefaultTopK: req.TopK,
	}, model.ModelConfigTypeVectorDashScope.Code())
	if err != nil {
		g.Log().Errorf(ctx, "初始化向量检索器失败: %v", err)
		return nil, fmt.Errorf("初始化向量检索器失败: %w", err)
	}

	// 3. Run the vector retrieval.
	docs, err := r.Retrieve(ctx, req.Content, retriever.WithDSLInfo(map[string]any{
		"dataset_ids":  req.DatasetIds,
		"document_ids": req.DocumentIds,
	}))
	if err != nil {
		g.Log().Errorf(ctx, "向量检索失败: %v", err)
		return nil, fmt.Errorf("向量检索失败: %w", err)
	}

	// 4. Convert the conversation history into chat messages.
	messages := make([]*schema.Message, 0)
	if err = gconv.Struct(req.History, &messages); err != nil {
		g.Log().Errorf(ctx, "转换历史消息失败: %v", err)
		return nil, fmt.Errorf("转换历史消息失败: %w", err)
	}

	// 5. Ask the chat model; this step gets its own error text so it is
	// not confused with a retrieval failure.
	replyMsg, err := eino.NewChatModel(ctx, req.Content, docs, messages, modelInfo.ConfigType)
	if err != nil {
		g.Log().Errorf(ctx, "生成回答失败: %v", err)
		return nil, fmt.Errorf("生成回答失败: %w", err)
	}
	return &dto.RAGQueryRes{
		Answer: replyMsg.Content,
	}, nil
}
// Update forwards the update request for a document vector row to the DAO.
// The affected-row count is discarded; only the error is reported.
func (s *documentVectorService) Update(ctx context.Context, req *dto.UpdateDocumentVectorReq) (err error) {
	if _, err = dao.DocumentVector.Update(ctx, req); err != nil {
		return err
	}
	return nil
}
// List returns the document vector rows selected by req together with the
// total count reported by the DAO. Rows are converted into the response list
// with gconv.Struct; a conversion error is returned alongside the partially
// filled response.
func (s *documentVectorService) List(ctx context.Context, req *dto.ListDocumentVectorReq) (res *dto.ListDocumentVectorRes, err error) {
	rows, count, err := dao.DocumentVector.List(ctx, req)
	if err != nil {
		return nil, err
	}
	res = &dto.ListDocumentVectorRes{Total: count}
	if err = gconv.Struct(rows, &res.List); err != nil {
		return res, err
	}
	return res, nil
}
// DocsChunkMsg consumes a document-chunk message and persists its vectors.
//
// The "data" entry of msg is decoded into schema.Document values. Documents
// whose metadata flag "isNew" is true are embedded and stored through the
// PGVector indexer. The remaining documents are inserted directly, reusing the
// vector of an existing row with the same content hash when one is found.
//
// The outcome is recorded as a generate-vector task whose id is the document
// id taken from the first document's metadata. On a storage failure the task
// is marked failed and the result of that progress write is returned.
// A decode failure of the message itself is returned as-is; an empty message
// is logged and yields a nil error.
func (s *documentVectorService) DocsChunkMsg(ctx context.Context, msg any) (err error) {
	var docs = make([]*schema.Document, 0)
	msgMap := gconv.Map(msg)
	if err = gconv.Structs(msgMap["data"], &docs); err != nil {
		g.Log().Error(ctx, "DocsChunkMsg err:", err)
		return
	}
	if len(docs) == 0 {
		g.Log().Error(ctx, "DocsChunkMsg err:", "msg is empty")
		return
	}

	// Tenant and creator are read from the first document's metadata and
	// attached to the context under the "user" key.
	ctx = context.WithValue(ctx, "user", &beans.User{
		TenantId: gconv.Uint64(docs[0].MetaData[entity.DocumentVectorCol.TenantId]),
		UserName: gconv.String(docs[0].MetaData[entity.DocumentVectorCol.Creator]),
	})
	documentId := gconv.Int64(docs[0].MetaData[entity.DocumentVectorCol.DocumentId])

	// Split the batch: new content goes through the embedder, the rest is
	// inserted directly.
	var docsStore = make([]*schema.Document, 0)
	var docsInsert = make([]*dto.VectorDocumentVectorMsg, 0)
	for _, doc := range docs {
		if gconv.Bool(doc.MetaData["isNew"]) {
			docsStore = append(docsStore, doc)
			continue
		}
		ck := new(dto.VectorDocumentVectorMsg)
		// A metadata decode failure must not be ignored: inserting a
		// half-populated row would silently corrupt the vector table.
		if err = gconv.Struct(doc.MetaData, ck); err != nil {
			return s.failVectorTask(ctx, documentId, 0, err)
		}
		ck.Content = doc.Content
		ck.VectorStatus = gconv.PtrInt8(1)
		ck.Status = gconv.PtrInt8(1)
		docsInsert = append(docsInsert, ck)
	}

	if !g.IsEmpty(docsStore) {
		idx := eino.NewPGVectorIndexer(&eino.PGVectorIndexerOptions{
			BatchSize: 10,
		})
		var rows int64
		// TODO: switch to a local model later.
		rows, err = idx.Store(ctx, docsStore, model.ModelConfigTypeVectorDashScope.Code())
		if err != nil || rows == 0 {
			return s.failVectorTask(ctx, documentId, rows, err)
		}
	}

	if !g.IsEmpty(docsInsert) {
		// 1. Collect every content hash of the rows to insert.
		contentHashs := make([]string, 0, len(docsInsert))
		for _, d := range docsInsert {
			contentHashs = append(contentHashs, d.ContentHash)
		}

		// 2. Load the existing vectors page by page (1000 per page) to
		// avoid one oversized query.
		var existVectors []*entity.DocumentVector
		for page := 1; ; page++ {
			res, total, listErr := dao.DocumentVector.List(ctx, &dto.ListDocumentVectorReq{
				Page:         &beans.Page{PageSize: 1000, PageNum: int64(page)},
				ContentHashs: contentHashs,
			})
			if listErr != nil {
				return listErr
			}
			if len(res) == 0 {
				break
			}
			existVectors = append(existVectors, res...)
			if len(existVectors) >= total {
				break
			}
		}

		// 3. Index the existing vectors by content hash.
		vectorMap := make(map[string]pgvector.Vector, len(existVectors))
		for _, v := range existVectors {
			vectorMap[v.ContentHash] = v.Vector
		}

		// 4. Back-fill the vector of every row whose hash already exists.
		// Rows are not filtered here; all of docsInsert is inserted.
		for _, d := range docsInsert {
			if vec, ok := vectorMap[d.ContentHash]; ok {
				d.Vector = vec
			}
		}

		var rows int64
		rows, err = dao.DocumentVector.BatchInsert(ctx, docsInsert)
		if err != nil || rows == 0 {
			return s.failVectorTask(ctx, documentId, rows, err)
		}
	}

	// Record the generate-vector task as completed.
	err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
		TaskId:   documentId,
		TaskType: task.TaskTypeGenerateVector,
		Status:   task.TaskStatusCompleted,
		Remark:   "向量生成完成",
	})
	return
}

// failVectorTask logs a vector storage failure and marks the generate-vector
// task of documentId as failed. The remark carries cause when it is non-nil,
// otherwise the number of stored rows. It returns the result of the progress
// write, not cause.
func (s *documentVectorService) failVectorTask(ctx context.Context, documentId int64, rows int64, cause error) error {
	g.Log().Error(ctx, "DocsChunkMsg rows: , err:", rows, cause)
	remark := " 向量存储数量: " + gconv.String(rows)
	if cause != nil {
		remark = "向量存储失败: " + cause.Error()
	}
	return Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{
		TaskId:   documentId,
		TaskType: task.TaskTypeGenerateVector,
		Status:   task.TaskStatusFailed,
		Remark:   remark,
	})
}

299
service/model.go Normal file
View File

@@ -0,0 +1,299 @@
package service
import (
"context"
"rag/common/eino"
"rag/consts/model"
"rag/consts/task"
"rag/dao"
"rag/model/dto"
"rag/model/entity"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// ModelService is the package-level instance of modelService.
var ModelService = new(modelService)
// modelService carries the model configuration service methods; it has no fields.
type modelService struct{}
// GetModelAllEnums returns every model type together with the configuration
// types available for it. ctx and req are not used.
//
// If converting the configuration options of a model type fails, the error is
// returned together with a response whose Options have not been set.
func (s *modelService) GetModelAllEnums(ctx context.Context, req *dto.GetModelAllEnumsReq) (res *dto.GetModelEnumRes, err error) {
	_, _ = ctx, req
	res = &dto.GetModelEnumRes{}

	var collected []dto.ModelEnumOption
	for _, typeOpt := range model.GetAllModelTypeEnums().Options {
		// Rebuild the model type from the option key, then look up the
		// configuration types registered for it.
		typeCode := model.ModelType(gconv.PtrString(gconv.String(typeOpt.Key)))
		configEnums := model.GetAllModelConfigTypeEnums(typeCode)

		var configTypes []dto.ModelKeyValue
		if err = gconv.Structs(configEnums.Options, &configTypes); err != nil {
			return res, err
		}
		collected = append(collected, dto.ModelEnumOption{
			Key:         typeOpt.Key,
			Value:       typeOpt.Value,
			ConfigTypes: configTypes,
		})
	}

	res.Options = collected
	return res, nil
}
// GetModelConfigFormFields builds the form field descriptors for creating a
// model of the requested model type and configuration type.
//
// The result always starts with the fixed fields (modelType, configType,
// modelName, modelDesc, model) and is followed by the provider-specific
// fields selected by req.ModelType and req.ConfigType. An unrecognised model
// or configuration type yields only the fixed fields. ctx is not used.
//
// An error is returned when req or req.ModelType is nil, or when the model
// type is chat/vector and req.ConfigType is nil; these inputs would otherwise
// cause a nil pointer dereference.
func (s *modelService) GetModelConfigFormFields(ctx context.Context, req *dto.GetModelConfigFormFieldsReq) (*dto.GetModelConfigFormFieldsRes, error) {
	_ = ctx
	if req == nil || req.ModelType == nil {
		return nil, gerror.New("模型类型不能为空")
	}
	isChat := *req.ModelType == *model.ModelTypeChat.Code()
	isVector := *req.ModelType == *model.ModelTypeVector.Code()
	if (isChat || isVector) && req.ConfigType == nil {
		return nil, gerror.New("配置类型不能为空")
	}

	fields := make([]map[string]any, 0)

	// ===================== Fixed base fields =====================
	// 1. Model type: read-only.
	fields = append(fields, map[string]any{
		"name":     "modelType",
		"label":    "模型类型",
		"type":     "text",
		"disabled": true,
		"required": true,
		"value":    model.GetModelTypeDescByCode(req.ModelType),
	})

	configTypeValue := "未知类型"
	switch {
	case isVector:
		configTypeValue = model.GetVectorDescByCode(req.ConfigType)
	case isChat:
		configTypeValue = model.GetChatDescByCode(req.ConfigType)
	}
	// 2. Configuration type: read-only.
	fields = append(fields, map[string]any{
		"name":     "configType",
		"label":    "配置类型",
		"type":     "text",
		"disabled": true,
		"required": true,
		"value":    configTypeValue,
	})

	// 3. Basic information.
	fields = append(fields, []map[string]any{
		{
			"name":        "modelName",
			"label":       "模型名称",
			"type":        "input",
			"required":    true,
			"placeholder": "例如DeepSeek 对话模型",
		},
		{
			"name":     "modelDesc",
			"label":    "模型描述",
			"type":     "textarea",
			"required": false,
		},
	}...)

	// 4. Provider model identifier.
	fields = append(fields, map[string]any{
		"name":        "model",
		"label":       "模型类型",
		"type":        "input",
		"required":    true,
		"placeholder": "例如deepseek-chat / text-embedding-3-small",
	})

	// ===================== Dynamic configuration fields =====================
	// Chosen by model type first, then by configuration type.
	switch {
	case isChat:
		switch *req.ConfigType {
		case *model.ModelConfigTypeChatArk.Code(), *model.ModelConfigTypeChatArkBot.Code():
			fields = append(fields, map[string]any{"name": "api_key", "label": "API Key", "type": "input", "required": true})
		case *model.ModelConfigTypeChatClaude.Code():
			fields = append(fields, []map[string]any{
				{"name": "by_bedrock", "label": "使用 AWS Bedrock", "type": "switch", "default": true},
				{"name": "access_key", "label": "Access Key", "type": "input"},
				{"name": "secret_access_key", "label": "Secret Access Key", "type": "input"},
				{"name": "region", "label": "Region", "type": "input"},
				{"name": "api_key", "label": "API Key", "type": "input"},
				{"name": "base_url", "label": "Base URL", "type": "input"},
			}...)
		case *model.ModelConfigTypeChatDeepSeek.Code():
			fields = append(fields, []map[string]any{
				{"name": "api_key", "label": "API Key", "type": "input", "required": true},
				{"name": "base_url", "label": "Base URL", "type": "input", "default": "https://api.deepseek.com"},
			}...)
		case *model.ModelConfigTypeChatOllama.Code():
			fields = append(fields, map[string]any{"name": "base_url", "label": "Base URL", "type": "input", "required": true, "default": "http://127.0.0.1:11434"})
		case *model.ModelConfigTypeChatOpenAI.Code():
			fields = append(fields, []map[string]any{
				{"name": "api_key", "label": "API Key", "type": "input", "required": true},
				{"name": "by_azure", "label": "使用 Azure", "type": "switch", "default": true},
				{"name": "base_url", "label": "Base URL", "type": "input"},
				{"name": "api_version", "label": "API Version", "type": "input"},
			}...)
		case *model.ModelConfigTypeChatQianfan.Code():
			fields = append(fields, []map[string]any{
				{"name": "access_key", "label": "Access Key", "type": "input", "required": true},
				{"name": "secret_key", "label": "Secret Key", "type": "input", "required": true},
			}...)
		case *model.ModelConfigTypeChatQwen.Code():
			fields = append(fields, []map[string]any{
				{"name": "api_key", "label": "API Key", "type": "input", "required": true},
				{"name": "base_url", "label": "Base URL", "type": "input"},
			}...)
		}
	case isVector:
		switch *req.ConfigType {
		case *model.ModelConfigTypeVectorArk.Code():
			fields = append(fields, []map[string]any{
				{"name": "api_key", "label": "API Key", "type": "input", "required": true},
				{"name": "api_type", "label": "API Type", "type": "input"},
			}...)
		case *model.ModelConfigTypeVectorOllama.Code():
			fields = append(fields, map[string]any{"name": "base_url", "label": "Base URL", "type": "input", "required": true, "default": "http://127.0.0.1:11434"})
		case *model.ModelConfigTypeVectorOpenAI.Code():
			fields = append(fields, []map[string]any{
				{"name": "api_key", "label": "API Key", "type": "input", "required": true},
				{"name": "by_azure", "label": "使用 Azure", "type": "switch", "default": true},
				{"name": "base_url", "label": "Base URL", "type": "input"},
				{"name": "api_version", "label": "API Version", "type": "input"},
			}...)
		case *model.ModelConfigTypeVectorQianfan.Code():
			fields = append(fields, []map[string]any{
				{"name": "access_key", "label": "Access Key", "type": "input", "required": true},
				{"name": "secret_key", "label": "Secret Key", "type": "input", "required": true},
			}...)
		case *model.ModelConfigTypeVectorTencentCloud.Code():
			fields = append(fields, []map[string]any{
				{"name": "secret_id", "label": "Secret ID", "type": "input", "required": true},
				{"name": "secret_key", "label": "Secret Key", "type": "input", "required": true},
				{"name": "region", "label": "Region", "type": "input", "required": true, "default": "ap-beijing"},
			}...)
		case *model.ModelConfigTypeVectorDashScope.Code():
			fields = append(fields, map[string]any{"name": "api_key", "label": "API Key", "type": "input", "required": true})
		}
	}

	return &dto.GetModelConfigFormFieldsRes{
		ModelType:  req.ModelType,
		ConfigType: req.ConfigType,
		Fields:     fields,
	}, nil
}
// Create inserts a new model configuration and then refreshes the cached
// model for it.
//
// It fails with "模型配置已存在" when the DAO already counts a model of the
// same model type. If the insert succeeds but the refresh fails, the response
// with the new id is returned together with the refresh error.
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
	existing, err := dao.Model.Count(ctx, &dto.GetModelReq{ModelType: req.ModelType})
	if err != nil {
		return nil, err
	}
	if existing > 0 {
		return nil, gerror.New("模型配置已存在")
	}

	id, err := dao.Model.Insert(ctx, req)
	if err != nil {
		return nil, err
	}

	res = &dto.CreateModelRes{Id: id}
	return res, s.refresh(ctx, id)
}
// Update changes a model configuration and refreshes the cached model when
// at least one row was updated.
//
// The update is rejected while the DAO counts any task in running status.
func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) (err error) {
	running, err := dao.Task.Count(ctx, &dto.GetTaskReq{TaskStatus: task.TaskStatusRunning})
	if err != nil {
		return err
	}
	if !g.IsEmpty(running) {
		return gerror.New("任务正在执行中,模型配置暂时不可修改,请稍后再试")
	}

	affected, err := dao.Model.Update(ctx, req)
	if err != nil {
		return err
	}
	// Nothing changed, so there is nothing to refresh.
	if g.IsEmpty(affected) {
		return nil
	}
	return s.refresh(ctx, req.Id)
}
// refresh reloads the model with the given id and pushes it into the eino
// tenant cache: chat models through RefreshTenantChatModel, vector models
// through RefreshTenantEmbedder. Other model types are left untouched.
//
// It returns an error when the model cannot be loaded, does not exist, or has
// no model type.
func (s *modelService) refresh(ctx context.Context, id int64) (err error) {
	var modelDO *entity.Model
	modelDO, err = dao.Model.Get(ctx, &dto.GetModelReq{
		Id: id,
	})
	if err != nil {
		return err
	}
	// The DAO may report "not found" as a nil model without an error;
	// dereferencing it below would panic.
	if modelDO == nil || modelDO.ModelType == nil {
		return gerror.Newf("模型不存在或模型类型为空: %d", id)
	}

	modelType := *modelDO.ModelType
	if modelType == *model.ModelTypeChat.Code() {
		if err = eino.RefreshTenantChatModel(ctx, modelDO); err != nil {
			return err
		}
	}
	if modelType == *model.ModelTypeVector.Code() {
		if err = eino.RefreshTenantEmbedder(ctx, modelDO); err != nil {
			return err
		}
	}
	return nil
}
// Delete forwards the delete request to the model DAO. The affected-row
// count is discarded; only the error is reported.
func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) (err error) {
	if _, err = dao.Model.Delete(ctx, req); err != nil {
		return err
	}
	return nil
}
// Get loads the model selected by req and converts it into a ModelVO.
//
// A DAO error is returned immediately instead of being overwritten by the
// result of the conversion.
func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (res *dto.ModelVO, err error) {
	r, err := dao.Model.Get(ctx, req)
	if err != nil {
		return nil, err
	}
	err = gconv.Struct(r, &res)
	return res, err
}
// List returns the models selected by req together with the total count
// reported by the DAO. Rows are converted into the response list with
// gconv.Struct; a conversion error is returned alongside the partially filled
// response.
func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (res *dto.ListModelRes, err error) {
	rows, count, err := dao.Model.List(ctx, req)
	if err != nil {
		return nil, err
	}
	res = &dto.ListModelRes{Total: count}
	if err = gconv.Struct(rows, &res.List); err != nil {
		return res, err
	}
	return res, nil
}

148
service/task.go Normal file
View File

@@ -0,0 +1,148 @@
package service
import (
"context"
"rag/consts/document"
"rag/consts/public"
"rag/consts/task"
"rag/dao"
"rag/model/dto"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// Task is the package-level instance of taskService.
var Task = new(taskService)
// taskService carries the task progress service methods; it has no fields.
type taskService struct{}
// WriteTaskProgress records the status of one sub-task (req.TaskType) of the
// task identified by req.TaskId.
//
// Before writing, it loads all existing sub-tasks of req.TaskId and decides
// whether, with the new status applied, the three required sub-tasks are all
// completed (see IsAllSubTasks). Inside one database transaction it then:
//   - inserts the sub-task, or updates it when a row for the same task id and
//     type already exists;
//   - when everything is completed, marks the rows of req.TaskId as completed
//     and sets the document's vector status to completed;
//   - otherwise, when the new status is failed, sets the document's vector
//     status to failed.
//
// Any DAO error aborts the transaction by being returned from the closure.
func (s *taskService) WriteTaskProgress(ctx context.Context, req *dto.WriteTaskProgressReq) (err error) {
	t, total, err := dao.Task.Get(ctx, &dto.GetTaskReq{
		TaskId: req.TaskId,
	})
	if err != nil {
		g.Log().Errorf(ctx, "查询任务失败: %v", err)
		return err
	}

	completed := false
	if total != 0 {
		taskVO := make([]*dto.TaskVO, 0, total)
		if err = gconv.Struct(t, &taskVO); err != nil {
			g.Log().Errorf(ctx, "转换任务失败: %v", err)
			return err
		}
		// Evaluate completion as if the incoming status were already stored.
		taskVO = append(taskVO, &dto.TaskVO{
			TaskType: req.TaskType,
			Status:   req.Status,
		})
		completed = IsAllSubTasks(taskVO, task.TaskStatusCompleted)
	}

	// 1. Look for an existing sub-task of this type.
	existTask, _, err := dao.Task.Get(ctx, &dto.GetTaskReq{
		TaskId:   req.TaskId,
		TaskType: req.TaskType,
	})
	if err != nil {
		g.Log().Errorf(ctx, "查询任务失败: %v", err)
		return err
	}

	return gfdb.DB(ctx, public.DbNameKnowledge).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
		if g.IsEmpty(existTask) {
			// 2. No such sub-task yet: create it. The insert error must be
			// checked, otherwise the transaction would commit after a
			// failed insert.
			if _, err := dao.Task.Insert(ctx, &dto.CreateTaskReq{
				TaskId:   req.TaskId,
				TaskType: req.TaskType,
				Status:   req.Status,
				Remark:   req.Remark,
			}); err != nil {
				g.Log().Errorf(ctx, "创建任务失败: %v", err)
				return err
			}
		} else {
			// 3. Sub-task exists: update its status and remark.
			if _, err := dao.Task.Update(ctx, &dto.UpdateTaskReq{
				Id:     existTask[0].Id,
				Status: req.Status,
				Remark: req.Remark,
			}); err != nil {
				g.Log().Errorf(ctx, "更新任务失败: %v", err)
				return err
			}
		}

		switch {
		case completed:
			// 4. All sub-tasks are done: close the task rows and the document.
			if _, err := dao.Task.Update(ctx, &dto.UpdateTaskReq{
				TaskId: req.TaskId,
				Status: task.TaskStatusCompleted,
				Remark: "文档解析完成",
			}); err != nil {
				g.Log().Errorf(ctx, "更新任务失败: %v", err)
				return err
			}
			if _, err := dao.Document.Update(ctx, &dto.UpdateDocumentReq{
				Id:           req.TaskId,
				VectorStatus: document.VectorStatusCompleted.Code(),
			}); err != nil {
				return err
			}
		case task.TaskStatusFailed == req.Status:
			// 5. A failed sub-task fails the document's vector status.
			if _, err := dao.Document.Update(ctx, &dto.UpdateDocumentReq{
				Id:           req.TaskId,
				VectorStatus: document.VectorStatusFailed.Code(),
			}); err != nil {
				return err
			}
		}
		return nil
	})
}
// IsAllSubTasks reports whether subTasks contains, for each of the three
// required task types (keyword extraction, vector generation, full-text
// search), at least one entry whose status equals taskStatus.
func IsAllSubTasks(subTasks []*dto.TaskVO, taskStatus task.TaskStatus) bool {
	var keywordsSeen, vectorSeen, fullTextSeen bool

	for _, sub := range subTasks {
		// Only entries in the requested status count.
		if sub.Status != taskStatus {
			continue
		}
		switch sub.TaskType {
		case task.TaskTypeExtractKeywords:
			keywordsSeen = true
		case task.TaskTypeGenerateVector:
			vectorSeen = true
		case task.TaskTypeFullTextSearch:
			fullTextSeen = true
		}
	}

	if !keywordsSeen || !vectorSeen {
		return false
	}
	return fullTextSeen
}
// Get returns the tasks selected by req together with the total count
// reported by the DAO. Rows are converted into the response list with
// gconv.Struct; a conversion error is returned alongside the partially filled
// response.
func (s *taskService) Get(ctx context.Context, req *dto.GetTaskReq) (res *dto.ListTaskRes, err error) {
	rows, count, err := dao.Task.Get(ctx, req)
	if err != nil {
		return nil, err
	}
	res = &dto.ListTaskRes{Total: count}
	if err = gconv.Struct(rows, &res.List); err != nil {
		return res, err
	}
	return res, nil
}

View File

@@ -114,6 +114,7 @@ COMMENT ON COLUMN rag_knowledge_document.file_path IS '文件存储路径如M
COMMENT ON COLUMN rag_knowledge_document.metadata IS '文件元数据,结构:{"author":"作者","tags":["标签1","标签2"],"custom":{"key":"值"}}';
--------------------pgsql创建rag_knowledge_document表语句---------------------------
--------------------pgsql创建rag_knowledge_keyword表语句---------------------------
-- 关键词表(文档关键词+权重)
CREATE TABLE IF NOT EXISTS rag_knowledge_keyword (
@@ -158,18 +159,104 @@ COMMENT ON COLUMN rag_knowledge_keyword.dataset_id IS '数据集ID';
COMMENT ON COLUMN rag_knowledge_keyword.document_id IS '文档ID';
COMMENT ON COLUMN rag_knowledge_keyword.word IS '关键词';
COMMENT ON COLUMN rag_knowledge_keyword.weight IS '权重';
CREATE UNIQUE INDEX uk_rag_knowledge_keyword_tenant_dataset_doc_word ON rag_knowledge_keyword (tenant_id, dataset_id, document_id, word);
--------------------pgsql创建rag_knowledge_keyword表语句---------------------------
--------------------pgsql创建rag_knowledge_task表语句---------------------------
-- Knowledge base task table
CREATE TABLE IF NOT EXISTS rag_knowledge_task (
-- Base fields (aligned with the project convention)
id BIGINT PRIMARY KEY, -- primary key ID, not auto-increment
tenant_id BIGINT NOT NULL DEFAULT 0, -- tenant ID (int8)
creator VARCHAR(64) NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updater VARCHAR(64) NOT NULL,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
deleted_at timestamp(6),
-- Business fields
task_id BIGINT NOT NULL, -- task ID
task_type VARCHAR(32) NOT NULL, -- task type
status VARCHAR(32) NOT NULL, -- task status
executor VARCHAR(128) DEFAULT '', -- executor
remark TEXT DEFAULT '' -- remark
);
-- Indexes (frequently queried columns)
CREATE INDEX idx_rkt_tenant_id ON rag_knowledge_task(tenant_id);
CREATE INDEX idx_rkt_task_id ON rag_knowledge_task(task_id);
CREATE INDEX idx_rkt_task_type ON rag_knowledge_task(task_type);
CREATE INDEX idx_rkt_status ON rag_knowledge_task(status);
CREATE INDEX idx_rkt_deleted_at ON rag_knowledge_task(deleted_at);
-- 表和字段注释
COMMENT ON TABLE rag_knowledge_task IS '知识库任务表';
COMMENT ON COLUMN rag_knowledge_task.id IS '主键ID非自增';
COMMENT ON COLUMN rag_knowledge_task.tenant_id IS '租户ID';
COMMENT ON COLUMN rag_knowledge_task.creator IS '创建人';
COMMENT ON COLUMN rag_knowledge_task.created_at IS '创建时间';
COMMENT ON COLUMN rag_knowledge_task.updater IS '更新人';
COMMENT ON COLUMN rag_knowledge_task.updated_at IS '更新时间';
COMMENT ON COLUMN rag_knowledge_task.deleted_at IS '删除时间(软删)';
COMMENT ON COLUMN rag_knowledge_task.task_id IS '任务ID';
COMMENT ON COLUMN rag_knowledge_task.task_type IS '任务类型';
COMMENT ON COLUMN rag_knowledge_task.status IS '任务状态';
COMMENT ON COLUMN rag_knowledge_task.executor IS '执行器';
COMMENT ON COLUMN rag_knowledge_task.remark IS '备注';
--------------------pgsql创建rag_knowledge_task表语句---------------------------
--------------------pgsql创建rag_knowledge_model表语句---------------------------
-- Knowledge base model configuration table
CREATE TABLE IF NOT EXISTS rag_knowledge_model (
-- Base fields (aligned with the project convention)
id BIGINT PRIMARY KEY, -- primary key ID, not auto-increment
tenant_id BIGINT NOT NULL DEFAULT 0, -- tenant ID (int8)
creator VARCHAR(64) NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updater VARCHAR(64) NOT NULL,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
deleted_at timestamp(6),
-- Business fields
dataset_id BIGINT NOT NULL, -- dataset ID
model_type VARCHAR(32) NOT NULL, -- model type
model_name VARCHAR(128) NOT NULL, -- model name
model_desc TEXT DEFAULT '', -- model description
model_config JSONB DEFAULT '{}'::JSONB -- model configuration (JSONB)
);
-- Indexes (frequently queried columns)
CREATE INDEX idx_rkm_tenant_id ON rag_knowledge_model(tenant_id);
CREATE INDEX idx_rkm_dataset_id ON rag_knowledge_model(dataset_id);
CREATE INDEX idx_rkm_model_type ON rag_knowledge_model(model_type);
CREATE INDEX idx_rkm_deleted_at ON rag_knowledge_model(deleted_at);
-- 表和字段注释
COMMENT ON TABLE rag_knowledge_model IS '知识库模型配置表';
COMMENT ON COLUMN rag_knowledge_model.id IS '主键ID非自增';
COMMENT ON COLUMN rag_knowledge_model.tenant_id IS '租户ID';
COMMENT ON COLUMN rag_knowledge_model.creator IS '创建人';
COMMENT ON COLUMN rag_knowledge_model.created_at IS '创建时间';
COMMENT ON COLUMN rag_knowledge_model.updater IS '更新人';
COMMENT ON COLUMN rag_knowledge_model.updated_at IS '更新时间';
COMMENT ON COLUMN rag_knowledge_model.deleted_at IS '删除时间(软删)';
COMMENT ON COLUMN rag_knowledge_model.dataset_id IS '数据集ID';
COMMENT ON COLUMN rag_knowledge_model.model_type IS '模型类型';
COMMENT ON COLUMN rag_knowledge_model.model_name IS '模型名称';
COMMENT ON COLUMN rag_knowledge_model.model_desc IS '模型描述';
COMMENT ON COLUMN rag_knowledge_model.model_config IS '模型配置(JSONB)';
--------------------pgsql创建rag_knowledge_model表语句---------------------------
--------------------pgsql创建rag_vector_dataset_index表语句---------------------------
-- 向量数据集索引表
CREATE TABLE IF NOT EXISTS rag_vector_dataset_index (
-- 基础字段
id BIGINT PRIMARY KEY, -- 主键ID非自增
tenant_id BIGINT NOT NULL DEFAULT 0, -- 租户ID int8
creator VARCHAR(64) NOT NULL,
id BIGINT PRIMARY KEY, -- 主键ID非自增
tenant_id BIGINT NOT NULL DEFAULT 0, -- 租户ID int8
creator VARCHAR(64) NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updater VARCHAR(64) NOT NULL,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
@@ -216,16 +303,16 @@ COMMENT ON COLUMN rag_vector_dataset_index.description IS '描述';
--------------------pgsql创建rag_vector_dataset_index表语句---------------------------
--------------------pgsql创建rag_vector_document_chunk表语句---------------------------
--------------------pgsql创建rag_vector_document_vector表语句---------------------------
CREATE EXTENSION IF NOT EXISTS vector;
-- 文档分块向量表
CREATE TABLE IF NOT EXISTS rag_vector_document_chunk (
CREATE TABLE IF NOT EXISTS rag_vector_document_vector (
-- 基础字段
id BIGINT PRIMARY KEY, -- 主键ID非自增
tenant_id BIGINT NOT NULL DEFAULT 0, -- 租户ID int8
creator VARCHAR(64) NOT NULL,
id BIGINT PRIMARY KEY, -- 主键ID非自增
tenant_id BIGINT NOT NULL DEFAULT 0, -- 租户ID int8
creator VARCHAR(64) NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updater VARCHAR(64) NOT NULL,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
@@ -248,30 +335,30 @@ CREATE TABLE IF NOT EXISTS rag_vector_document_chunk (
);
-- 索引
CREATE INDEX idx_chunk_tenant_id ON rag_vector_document_chunk(tenant_id);
CREATE INDEX idx_chunk_dataset_id ON rag_vector_document_chunk(dataset_id);
CREATE INDEX idx_chunk_document_id ON rag_vector_document_chunk(document_id);
CREATE INDEX idx_chunk_content_hash ON rag_vector_document_chunk(content_hash);
CREATE INDEX idx_chunk_status ON rag_vector_document_chunk(status);
CREATE INDEX idx_chunk_vector_status ON rag_vector_document_chunk(vector_status);
CREATE INDEX idx_vector_tenant_id ON rag_vector_document_vector(tenant_id);
CREATE INDEX idx_vector_dataset_id ON rag_vector_document_vector(dataset_id);
CREATE INDEX idx_vector_document_id ON rag_vector_document_vector(document_id);
CREATE INDEX idx_vector_content_hash ON rag_vector_document_vector(content_hash);
CREATE INDEX idx_vector_status ON rag_vector_document_vector(status);
CREATE INDEX idx_vector_vector_status ON rag_vector_document_vector(vector_status);
-- 注释
COMMENT ON TABLE rag_vector_document_chunk IS '文档分块向量表';
COMMENT ON COLUMN rag_vector_document_chunk.id IS '主键ID非自增';
COMMENT ON COLUMN rag_vector_document_chunk.tenant_id IS '租户ID';
COMMENT ON COLUMN rag_vector_document_chunk.creator IS '创建人';
COMMENT ON COLUMN rag_vector_document_chunk.created_at IS '创建时间';
COMMENT ON COLUMN rag_vector_document_chunk.updater IS '更新人';
COMMENT ON COLUMN rag_vector_document_chunk.updated_at IS '更新时间';
COMMENT ON COLUMN rag_vector_document_chunk.deleted_at IS '删除时间(软删)';
COMMENT ON COLUMN rag_vector_document_chunk.status IS '状态';
COMMENT ON COLUMN rag_vector_document_chunk.vector_status IS '向量生成状态';
COMMENT ON COLUMN rag_vector_document_chunk.dataset_id IS '数据集ID';
COMMENT ON COLUMN rag_vector_document_chunk.document_id IS '文档ID';
COMMENT ON COLUMN rag_vector_document_chunk.content IS '分块内容';
COMMENT ON COLUMN rag_vector_document_chunk.content_hash IS '内容哈希';
COMMENT ON COLUMN rag_vector_document_chunk.chunk_index IS '分块序号';
COMMENT ON COLUMN rag_vector_document_chunk.vector IS '向量数据';
COMMENT ON COLUMN rag_vector_document_chunk.metadata IS '扩展元数据';
COMMENT ON TABLE rag_vector_document_vector IS '文档分块向量表';
COMMENT ON COLUMN rag_vector_document_vector.id IS '主键ID非自增';
COMMENT ON COLUMN rag_vector_document_vector.tenant_id IS '租户ID';
COMMENT ON COLUMN rag_vector_document_vector.creator IS '创建人';
COMMENT ON COLUMN rag_vector_document_vector.created_at IS '创建时间';
COMMENT ON COLUMN rag_vector_document_vector.updater IS '更新人';
COMMENT ON COLUMN rag_vector_document_vector.updated_at IS '更新时间';
COMMENT ON COLUMN rag_vector_document_vector.deleted_at IS '删除时间(软删)';
COMMENT ON COLUMN rag_vector_document_vector.status IS '状态';
COMMENT ON COLUMN rag_vector_document_vector.vector_status IS '向量生成状态';
COMMENT ON COLUMN rag_vector_document_vector.dataset_id IS '数据集ID';
COMMENT ON COLUMN rag_vector_document_vector.document_id IS '文档ID';
COMMENT ON COLUMN rag_vector_document_vector.content IS '分块内容';
COMMENT ON COLUMN rag_vector_document_vector.content_hash IS '内容哈希';
COMMENT ON COLUMN rag_vector_document_vector.chunk_index IS '分块序号';
COMMENT ON COLUMN rag_vector_document_vector.vector IS '向量数据';
COMMENT ON COLUMN rag_vector_document_vector.metadata IS '扩展元数据';
--------------------pgsql创建rag_vector_document_chunk表语句---------------------------
--------------------pgsql创建rag_vector_document_vector表语句---------------------------