diff --git a/config/constants.go b/config/constants.go new file mode 100644 index 0000000..0cf2867 --- /dev/null +++ b/config/constants.go @@ -0,0 +1,34 @@ +package config + +// ==================== 可配置常量 ==================== +// 修改以下值来调整系统行为 + +// -------------------- 追问配置 -------------------- + +// FollowUpDelay1 第一次追问延时(秒) +var FollowUpDelay1 = 30 + +// FollowUpDelay2 第二次追问延时(秒) +var FollowUpDelay2 = 60 + +// FollowUpDelay3 第三次追问延时(秒) +var FollowUpDelay3 = 180 + +// FollowUpContent1 第一次追问话术 +var FollowUpContent1 = "还有其他问题吗?" + +// FollowUpContent2 第二次追问话术 +var FollowUpContent2 = "如果需要帮助,随时告诉我~" + +// FollowUpContent3 第三次追问话术 +var FollowUpContent3 = "我一直在线,有问题随时找我~" + +// -------------------- 归档配置 -------------------- + +// ArchiveDelay 归档延时(秒),默认 1 小时 +var ArchiveDelay = 3600 + +// -------------------- 历史上下文配置 -------------------- + +// HistoryContextLimit 读取历史对话轮数(用于新 Session 上下文注入) +var HistoryContextLimit int64 = 5 diff --git a/rabbitmq/publisher.go b/rabbitmq/publisher.go index 182622b..d9753b3 100644 --- a/rabbitmq/publisher.go +++ b/rabbitmq/publisher.go @@ -23,8 +23,13 @@ func NewPublisher(exchange, routingKey string) *Publisher { } } -// Publish 发布消息 +// Publish 发布消息(使用默认 routing key) func (p *Publisher) Publish(ctx context.Context, message interface{}) (err error) { + return p.PublishWithRoutingKey(ctx, p.routingKey, message) +} + +// PublishWithRoutingKey 发布消息(指定 routing key) +func (p *Publisher) PublishWithRoutingKey(ctx context.Context, routingKey string, message interface{}) (err error) { ch, err := GetChannel() if err != nil { return err @@ -39,10 +44,10 @@ func (p *Publisher) Publish(ctx context.Context, message interface{}) (err error // 发布消息 err = ch.PublishWithContext( ctx, - p.exchange, // exchange - p.routingKey, // routing key - false, // mandatory - false, // immediate + p.exchange, // exchange + routingKey, // routing key + false, // mandatory + false, // immediate amqp.Publishing{ DeliveryMode: amqp.Persistent, // 持久化 ContentType: "application/json", @@ -52,12 +57,12 @@ func (p *Publisher) Publish(ctx context.Context, message interface{}) (err error if err != nil { g.Log().Errorf(ctx, "发布消息失败: exchange=%s, routingKey=%s, err=%v", - p.exchange, p.routingKey, err) + p.exchange, routingKey, err) return err } g.Log().Debugf(ctx, "消息发布成功: exchange=%s, routingKey=%s", - p.exchange, p.routingKey) + p.exchange, routingKey) return } diff --git a/ragflow/client.go b/ragflow/client.go index 6b4ccc5..618d84b 100644 --- a/ragflow/client.go +++ b/ragflow/client.go @@ -5,6 +5,7 @@ import ( "net/http" "net/url" "strings" + "time" "github.com/gogf/gf/v2/encoding/gjson" "github.com/gogf/gf/v2/errors/gerror" @@ -34,6 +35,7 @@ func init() { httpClient := gclient.New() httpClient.SetHeader("Authorization", "Bearer "+apiKey) httpClient.SetHeader("Content-Type", "application/json") + httpClient.SetTimeout(60 * time.Second) // RAGFlow AI 推理需要较长时间 globalClient = &Client{ BaseURL: strings.TrimSuffix(baseURL, "/"), diff --git a/ragflow/worker_pool.go b/ragflow/worker_pool.go index b836ef8..06f7878 100644 --- a/ragflow/worker_pool.go +++ b/ragflow/worker_pool.go @@ -2,58 +2,28 @@ package ragflow import ( "context" - "sync" "gitee.com/red-future---jilin-g/common/redis" - "github.com/gogf/gf/v2/errors/gerror" "github.com/gogf/gf/v2/os/glog" "github.com/gogf/gf/v2/os/grpool" ) -// WorkerPool RAGFlow 请求处理协程池 +// 默认协程池大小 +const defaultPoolSize = 200 + +// workerPool 协程池单例(grpool.New 是原型模式,需要变量引用) +var workerPool = grpool.New(defaultPoolSize) + +// WorkerPool RAGFlow 请求处理协程池(封装 grpool) type WorkerPool struct { pool *grpool.Pool size int } -// 单例模式相关变量 -var ( - workerPoolInstance *WorkerPool - workerPoolOnce sync.Once -) - -// GetWorkerPoolWithSize 获取指定大小的协程池单例 -// 使用 sync.Once 确保只创建一次,size 仅首次调用生效 -func GetWorkerPoolWithSize(size int) *WorkerPool { - workerPoolOnce.Do(func() { - if size <= 0 { - size = 200 // 默认大小 - } - workerPoolInstance = &WorkerPool{ - pool: grpool.New(size), - size: size, - } - }) - return workerPoolInstance -} - -// GetWorkerPool 获取协程池单例(使用默认大小 200) -func GetWorkerPool() *WorkerPool { - return GetWorkerPoolWithSize(200) -} - -// NewWorkerPool 创建协程池(兼容旧代码,内部使用单例) -// 参数: -// - size: 协程池大小,仅首次调用生效 -// -// 返回: -// - *WorkerPool: 协程池单例实例 -// - error: 创建失败时返回错误 -func NewWorkerPool(size int) (*WorkerPool, error) { - if size <= 0 { - return nil, gerror.New("协程池大小必须大于0") - } - return GetWorkerPoolWithSize(size), nil +// Pool 协程池单例实例(直接引用使用) +var Pool = &WorkerPool{ + pool: workerPool, + size: defaultPoolSize, } // Submit 提交任务到协程池 diff --git a/redis/redis.go b/redis/redis.go index 186d2ad..f90f3b1 100644 --- a/redis/redis.go +++ b/redis/redis.go @@ -3,38 +3,18 @@ package redis import ( "context" "strings" - "sync" - "github.com/gogf/gf/v2/database/gredis" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/os/glog" "github.com/gogf/gf/v2/os/gtime" "github.com/gogf/gf/v2/util/gconv" ) -var ( - // redisClient 单例 Redis 客户端 - redisClient *gredis.Redis - // redisOnce 确保只初始化一次 - redisOnce sync.Once - // RedisClient 兼容导出(供 mongo.go 使用) - // 注意:这是一个指向单例的指针,首次调用 GetRedisClient() 后生效 - RedisClient *gredis.Redis -) +// redisClient 内部使用的 Redis 客户端(g.Redis() 是原型模式,需要变量引用) +var redisClient = g.Redis() -// GetRedisClient 获取 Redis 客户端(单例模式) -func GetRedisClient() *gredis.Redis { - redisOnce.Do(func() { - redisClient = g.Redis() - RedisClient = redisClient // 同步更新兼容导出 - }) - return redisClient -} - -// init 包初始化时自动初始化 Redis 客户端 -func init() { - GetRedisClient() -} +// RedisClient 导出的 Redis 客户端(供 mongo.go 使用) +var RedisClient = redisClient // Stream 和消费者组常量 const ( @@ -62,7 +42,7 @@ type StreamMessage struct { // 使用 gredis Do() 方法执行 XGROUP CREATE 命令 func InitStreamGroup(ctx context.Context, streamKey, groupName string) error { // XGROUP CREATE streamKey groupName 0 MKSTREAM - _, err := GetRedisClient().Do(ctx, "XGROUP", "CREATE", streamKey, groupName, "0", "MKSTREAM") + _, err := redisClient.Do(ctx, "XGROUP", "CREATE", streamKey, groupName, "0", "MKSTREAM") if err != nil { // 如果组已存在,忽略错误 errStr := err.Error() @@ -76,21 +56,25 @@ func InitStreamGroup(ctx context.Context, streamKey, groupName string) error { // AddToStream 将消息添加到 Stream // 使用 gredis Do() 方法执行 XADD 命令 -func AddToStream(ctx context.Context, streamKey string, values map[string]interface{}) (string, error) { +// msg 可以是结构体或 map,内部自动转换 +func AddToStream(ctx context.Context, streamKey string, msg interface{}) (messageID string, err error) { + // 将结构体转换为 map + values := gconv.Map(msg) + // XADD streamKey * field1 value1 field2 value2 ... - args := []interface{}{streamKey, "*"} // "*" 自动生成ID + args := make([]interface{}, 0, len(values)*2+2) + args = append(args, streamKey, "*") // "*" 自动生成ID for key, val := range values { args = append(args, key, val) } - result, err := GetRedisClient().Do(ctx, "XADD", args...) + result, err := redisClient.Do(ctx, "XADD", args...) if err != nil { - return "", err + return } - // 返回消息ID - messageID := result.String() - return messageID, nil + messageID = result.String() + return } // ReadFromStream 从 Stream 读取消息(消费者组模式) @@ -100,7 +84,7 @@ func ReadFromStream(ctx context.Context, streamKey, groupName, consumerName stri groupName, consumerName, count, blockMs, streamKey) // XREADGROUP GROUP groupName consumerName COUNT count BLOCK blockMs STREAMS streamKey > - result, err := GetRedisClient().Do(ctx, + result, err := redisClient.Do(ctx, "XREADGROUP", "GROUP", groupName, consumerName, "COUNT", count, "BLOCK", blockMs, @@ -208,7 +192,7 @@ func AckMessage(ctx context.Context, streamKey, groupName string, messageIDs ... args = append(args, id) } - _, err := GetRedisClient().Do(ctx, "XACK", args...) + _, err := redisClient.Do(ctx, "XACK", args...) return err } @@ -216,7 +200,7 @@ func AckMessage(ctx context.Context, streamKey, groupName string, messageIDs ... // 使用 gredis Do() 方法执行 XLEN 命令 func GetStreamLength(ctx context.Context, streamKey string) (int64, error) { // XLEN streamKey - result, err := GetRedisClient().Do(ctx, "XLEN", streamKey) + result, err := redisClient.Do(ctx, "XLEN", streamKey) if err != nil { return 0, err } @@ -237,7 +221,7 @@ type PendingMessage struct { // 使用 gredis Do() 方法执行 XPENDING 命令 func GetPendingMessages(ctx context.Context, streamKey, groupName string, start, end string, count int64) ([]PendingMessage, error) { // XPENDING streamKey groupName start end count - result, err := GetRedisClient().Do(ctx, "XPENDING", streamKey, groupName, start, end, count) + result, err := redisClient.Do(ctx, "XPENDING", streamKey, groupName, start, end, count) if err != nil { return nil, err } @@ -279,7 +263,7 @@ func ClaimPendingMessage(ctx context.Context, streamKey, groupName, consumerName args = append(args, id) } - result, err := GetRedisClient().Do(ctx, "XCLAIM", args...) + result, err := redisClient.Do(ctx, "XCLAIM", args...) if err != nil { return nil, err } @@ -333,7 +317,7 @@ func SetSessionLastActive(ctx context.Context, userId string) error { timestamp := gtime.Now().Timestamp() // SETEX key 7200 value (7200秒 = 2小时) - _, err := GetRedisClient().Do(ctx, "SETEX", key, 7200, timestamp) + _, err := redisClient.Do(ctx, "SETEX", key, 7200, timestamp) return err } @@ -341,7 +325,7 @@ func SetSessionLastActive(ctx context.Context, userId string) error { // 使用 gredis Get 方法 func GetSessionLastActive(ctx context.Context, userId string) (int64, error) { key := SessionLastActiveKeyPrefix + userId + ":last_active" - result, err := GetRedisClient().Get(ctx, key) + result, err := redisClient.Get(ctx, key) if err != nil { return 0, err } @@ -383,7 +367,7 @@ func SetSessionCache(ctx context.Context, userId, sessionId string) error { key := SessionLastActiveKeyPrefix + userId + ":session_id" // SETEX key 604800 value (604800秒 = 7天) - _, err := GetRedisClient().Do(ctx, "SETEX", key, 604800, sessionId) + _, err := redisClient.Do(ctx, "SETEX", key, 604800, sessionId) return err } @@ -397,7 +381,7 @@ const ( // windowSeconds: 时间窗口(秒) func IncrRateLimit(ctx context.Context, key string, windowSeconds int64) (count int64, err error) { fullKey := RateLimitKeyPrefix + key - result, err := GetRedisClient().Do(ctx, "INCR", fullKey) + result, err := redisClient.Do(ctx, "INCR", fullKey) if err != nil { return } @@ -405,7 +389,7 @@ func IncrRateLimit(ctx context.Context, key string, windowSeconds int64) (count // 首次设置过期时间 if count == 1 { - GetRedisClient().Do(ctx, "EXPIRE", fullKey, windowSeconds) + redisClient.Do(ctx, "EXPIRE", fullKey, windowSeconds) } return } @@ -413,7 +397,7 @@ func IncrRateLimit(ctx context.Context, key string, windowSeconds int64) (count // GetRateLimit 获取当前限流计数 func GetRateLimit(ctx context.Context, key string) (count int64, err error) { fullKey := RateLimitKeyPrefix + key - result, err := GetRedisClient().Get(ctx, fullKey) + result, err := redisClient.Get(ctx, fullKey) if err != nil { return } @@ -425,10 +409,9 @@ func GetRateLimit(ctx context.Context, key string) (count int64, err error) { } // GetSessionCache 获取缓存的 RAGFlow Session ID -// 使用 gredis Get 方法 func GetSessionCache(ctx context.Context, userId string) (string, error) { key := SessionLastActiveKeyPrefix + userId + ":session_id" - result, err := GetRedisClient().Get(ctx, key) + result, err := redisClient.Get(ctx, key) if err != nil { return "", err } @@ -439,3 +422,10 @@ func GetSessionCache(ctx context.Context, userId string) (string, error) { return result.String(), nil } + +// DelSessionCache 删除缓存的 RAGFlow Session ID(归档时调用) +func DelSessionCache(ctx context.Context, userId string) error { + key := SessionLastActiveKeyPrefix + userId + ":session_id" + _, err := redisClient.Del(ctx, key) + return err +} diff --git a/redis/types.go b/redis/types.go index ecdce57..59bf280 100644 --- a/redis/types.go +++ b/redis/types.go @@ -1,21 +1,23 @@ package redis +import "gitee.com/red-future---jilin-g/common/config" + +// HistoryMessage 历史消息结构(用于上下文注入) +type HistoryMessage struct { + Question string `json:"question"` // 用户问题 + Answer string `json:"answer"` // AI 回复 +} + // SendStreamMessage 发送到 Redis Stream 的消息结构 type SendStreamMessage struct { - UserId string `json:"user_id"` // 用户ID - Content string `json:"content"` // 消息内容 - Timestamp int64 `json:"timestamp"` // 时间戳(秒) - MessageId string `json:"message_id"` // 消息唯一ID -} - -// ToMap 转换为 map[string]interface{} 用于 Stream 存储 -func (m *SendStreamMessage) ToMap() map[string]interface{} { - return map[string]interface{}{ - "user_id": m.UserId, - "content": m.Content, - "timestamp": m.Timestamp, - "message_id": m.MessageId, - } + UserId string `json:"user_id"` // 用户ID + Content string `json:"content"` // 消息内容 + Timestamp int64 `json:"timestamp"` // 时间戳(秒) + MessageId string `json:"message_id"` // 消息唯一ID + Platform string `json:"platform,omitempty"` // 平台标识 + AccountId string `json:"account_id,omitempty"` // 账号ID + TenantId string `json:"tenant_id,omitempty"` // 租户ID(数据隔离) + History []HistoryMessage `json:"history,omitempty"` // 历史对话(归档后恢复时携带) } // BatchStreamMessage 批量消息结构 @@ -27,21 +29,11 @@ type BatchStreamMessage struct { Index int `json:"index"` // 批次内序号 } -// ToMap 转换为 map[string]interface{} 用于 Stream 存储 -func (m *BatchStreamMessage) ToMap() map[string]interface{} { - return map[string]interface{}{ - "user_id": m.UserId, - "content": m.Content, - "timestamp": m.Timestamp, - "batch_id": m.BatchId, - "index": m.Index, - } -} - -// ResponseStreamMessage RAGFlow 响应消息结构(写入结果 Stream) +// ResponseStreamMessage RAGFlow 响应消息结构(MQ 消息) type ResponseStreamMessage struct { UserId string `json:"user_id"` // 用户ID Platform string `json:"platform"` // 平台标识 + TenantId string `json:"tenant_id"` // 租户ID Question string `json:"question"` // 用户问题 Content string `json:"content"` // RAGFlow 回复内容 SessionId string `json:"session_id"` // RAGFlow Session ID @@ -49,19 +41,6 @@ type ResponseStreamMessage struct { MessageId string `json:"message_id"` // 原始消息ID } -// ToMap 转换为 map[string]interface{} 用于 Stream 存储 -func (m *ResponseStreamMessage) ToMap() map[string]interface{} { - return map[string]interface{}{ - "user_id": m.UserId, - "platform": m.Platform, - "question": m.Question, - "content": m.Content, - "session_id": m.SessionId, - "timestamp": m.Timestamp, - "message_id": m.MessageId, - } -} - // FollowUpMessage 追问消息结构(RabbitMQ 延时队列) type FollowUpMessage struct { UserId string `json:"user_id"` // 用户ID @@ -71,25 +50,39 @@ type FollowUpMessage struct { Timestamp int64 `json:"timestamp"` // 发送时间戳 } -// 追问话术常量 +// 追问类型常量 const ( - FollowUpType1 = 1 // 30秒追问 - FollowUpType2 = 2 // 60秒追问 - FollowUpType3 = 3 // 180秒追问 + FollowUpType1 = 1 // 第一次追问 + FollowUpType2 = 2 // 第二次追问 + FollowUpType3 = 3 // 第三次追问 ) -// 追问话术内容 -var FollowUpContents = map[int]string{ - FollowUpType1: "还有其他问题吗?", - FollowUpType2: "如果需要帮助,随时告诉我~", - FollowUpType3: "我一直在线,有问题随时找我~", +// GetFollowUpContent 获取追问话术(从 config 包读取) +func GetFollowUpContent(followUpType int) string { + switch followUpType { + case FollowUpType1: + return config.FollowUpContent1 + case FollowUpType2: + return config.FollowUpContent2 + case FollowUpType3: + return config.FollowUpContent3 + default: + return "" + } } -// 追问延时时间(秒) -var FollowUpDelays = map[int]int{ - FollowUpType1: 30, - FollowUpType2: 60, - FollowUpType3: 180, +// GetFollowUpDelay 获取追问延时(从 config 包读取) +func GetFollowUpDelay(followUpType int) int { + switch followUpType { + case FollowUpType1: + return config.FollowUpDelay1 + case FollowUpType2: + return config.FollowUpDelay2 + case FollowUpType3: + return config.FollowUpDelay3 + default: + return 0 + } } // ArchiveMessage 会话归档消息结构(RabbitMQ 延时队列) @@ -100,5 +93,12 @@ type ArchiveMessage struct { Timestamp int64 `json:"timestamp"` // 发送时间戳 } -// 归档延时时间(秒) -const ArchiveDelaySeconds = 3600 // 60分钟 +// GetArchiveDelay 获取归档延时(从 config 包读取) +func GetArchiveDelay() int { + return config.ArchiveDelay +} + +// GetHistoryContextLimit 获取历史上下文轮数(从 config 包读取) +func GetHistoryContextLimit() int64 { + return config.HistoryContextLimit +}