diff --git a/rabbitmq/client.go b/rabbitmq/client.go new file mode 100644 index 0000000..7738082 --- /dev/null +++ b/rabbitmq/client.go @@ -0,0 +1,179 @@ +package rabbitmq + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/gogf/gf/v2/frame/g" + amqp "github.com/rabbitmq/amqp091-go" +) + +var ( + conn *amqp.Connection + channel *amqp.Channel + once sync.Once + mu sync.RWMutex +) + +// Config RabbitMQ 配置 +type Config struct { + Host string + Port int + Username string + Password string + VHost string +} + +// Init 初始化 RabbitMQ 连接 +func Init(ctx context.Context, cfg *Config) error { + var err error + once.Do(func() { + // 构建连接字符串 + url := fmt.Sprintf("amqp://%s:%s@%s:%d/%s", + cfg.Username, + cfg.Password, + cfg.Host, + cfg.Port, + cfg.VHost, + ) + + // 创建连接 + conn, err = amqp.Dial(url) + if err != nil { + g.Log().Errorf(ctx, "RabbitMQ 连接失败: %v", err) + return + } + + // 创建 Channel + channel, err = conn.Channel() + if err != nil { + g.Log().Errorf(ctx, "创建 RabbitMQ Channel 失败: %v", err) + return + } + + // 监听连接关闭 + go handleConnectionClose(ctx) + + g.Log().Info(ctx, "RabbitMQ 连接成功") + }) + + return err +} + +// InitFromConfig 从配置文件初始化 +func InitFromConfig(ctx context.Context) error { + cfg := &Config{ + Host: g.Cfg().MustGet(ctx, "rabbitmq.host").String(), + Port: g.Cfg().MustGet(ctx, "rabbitmq.port").Int(), + Username: g.Cfg().MustGet(ctx, "rabbitmq.username").String(), + Password: g.Cfg().MustGet(ctx, "rabbitmq.password").String(), + VHost: g.Cfg().MustGet(ctx, "rabbitmq.vhost", "/").String(), + } + + return Init(ctx, cfg) +} + +// GetChannel 获取 Channel +func GetChannel() (*amqp.Channel, error) { + mu.RLock() + defer mu.RUnlock() + + if channel == nil || channel.IsClosed() { + return nil, fmt.Errorf("RabbitMQ Channel 未初始化或已关闭") + } + + return channel, nil +} + +// GetConnection 获取连接 +func GetConnection() (*amqp.Connection, error) { + mu.RLock() + defer mu.RUnlock() + + if conn == nil || conn.IsClosed() { + return nil, fmt.Errorf("RabbitMQ 连接未初始化或已关闭") + } + + return conn, nil +} + +// handleConnectionClose 监听连接关闭并重连 +func handleConnectionClose(ctx context.Context) { + closeErr := make(chan *amqp.Error) + conn.NotifyClose(closeErr) + + err := <-closeErr + if err != nil { + g.Log().Errorf(ctx, "RabbitMQ 连接关闭: %v,尝试重连...", err) + reconnect(ctx) + } +} + +// reconnect 重新连接 +func reconnect(ctx context.Context) { + mu.Lock() + defer mu.Unlock() + + for i := 0; i < 10; i++ { + time.Sleep(time.Duration(i+1) * time.Second) + + cfg := &Config{ + Host: g.Cfg().MustGet(ctx, "rabbitmq.host").String(), + Port: g.Cfg().MustGet(ctx, "rabbitmq.port").Int(), + Username: g.Cfg().MustGet(ctx, "rabbitmq.username").String(), + Password: g.Cfg().MustGet(ctx, "rabbitmq.password").String(), + VHost: g.Cfg().MustGet(ctx, "rabbitmq.vhost", "/").String(), + } + + url := fmt.Sprintf("amqp://%s:%s@%s:%d/%s", + cfg.Username, + cfg.Password, + cfg.Host, + cfg.Port, + cfg.VHost, + ) + + var err error + conn, err = amqp.Dial(url) + if err != nil { + g.Log().Errorf(ctx, "重连失败 (尝试 %d/10): %v", i+1, err) + continue + } + + channel, err = conn.Channel() + if err != nil { + g.Log().Errorf(ctx, "创建 Channel 失败 (尝试 %d/10): %v", i+1, err) + continue + } + + g.Log().Info(ctx, "RabbitMQ 重连成功") + go handleConnectionClose(ctx) + return + } + + g.Log().Fatal(ctx, "RabbitMQ 重连失败,已达到最大重试次数") +} + +// Close 关闭连接 +func Close(ctx context.Context) error { + mu.Lock() + defer mu.Unlock() + + if channel != nil { + if err := channel.Close(); err != nil { + g.Log().Errorf(ctx, "关闭 RabbitMQ Channel 失败: %v", err) + } + } + + if conn != nil { + if err := conn.Close(); err != nil { + g.Log().Errorf(ctx, "关闭 RabbitMQ 连接失败: %v", err) + return err + } + } + + g.Log().Info(ctx, "RabbitMQ 连接已关闭") + return nil +} diff --git a/rabbitmq/consumer.go b/rabbitmq/consumer.go new file mode 100644 index 0000000..f9fe3ae --- /dev/null +++ b/rabbitmq/consumer.go @@ -0,0 +1,165 @@ +package rabbitmq + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/gogf/gf/v2/frame/g" + amqp "github.com/rabbitmq/amqp091-go" +) + +// MessageHandler 消息处理函数 +type MessageHandler func(ctx context.Context, body []byte) error + +// Consumer 消费者 +type Consumer struct { + queue string + consumerTag string + prefetchCount int // QoS: 预取数量(并发控制) + autoAck bool // 是否自动确认 + handler MessageHandler + workerCount int // worker 数量 +} + +// ConsumerOption 消费者配置选项 +type ConsumerOption func(*Consumer) + +// WithPrefetchCount 设置预取数量(并发控制) +func WithPrefetchCount(count int) ConsumerOption { + return func(c *Consumer) { + c.prefetchCount = count + } +} + +// WithAutoAck 设置自动确认 +func WithAutoAck(autoAck bool) ConsumerOption { + return func(c *Consumer) { + c.autoAck = autoAck + } +} + +// WithWorkerCount 设置 worker 数量 +func WithWorkerCount(count int) ConsumerOption { + return func(c *Consumer) { + c.workerCount = count + } +} + +// WithConsumerTag 设置消费者标签 +func WithConsumerTag(tag string) ConsumerOption { + return func(c *Consumer) { + c.consumerTag = tag + } +} + +// NewConsumer 创建消费者 +func NewConsumer(queue string, handler MessageHandler, opts ...ConsumerOption) *Consumer { + c := &Consumer{ + queue: queue, + consumerTag: "", + prefetchCount: 1, // 默认 1 个 + autoAck: false, // 默认手动确认 + handler: handler, + workerCount: 1, // 默认 1 个 worker + } + + // 应用选项 + for _, opt := range opts { + opt(c) + } + + return c +} + +// Start 启动消费者 +func (c *Consumer) Start(ctx context.Context) error { + ch, err := GetChannel() + if err != nil { + return err + } + + // 设置 QoS(并发控制) + err = ch.Qos( + c.prefetchCount, // prefetchCount: 每个 consumer 最多同时处理的消息数 + 0, // prefetchSize: 0 表示不限制 + false, // global: false 表示仅应用于当前 channel + ) + if err != nil { + return fmt.Errorf("设置 QoS 失败: %v", err) + } + + // 开始消费 + msgs, err := ch.Consume( + c.queue, // queue + c.consumerTag, // consumer tag + c.autoAck, // auto-ack + false, // exclusive + false, // no-local + false, // no-wait + nil, // args + ) + if err != nil { + return fmt.Errorf("开始消费失败: %v", err) + } + + g.Log().Infof(ctx, "消费者已启动: queue=%s, prefetch=%d, workers=%d", + c.queue, c.prefetchCount, c.workerCount) + + // 启动多个 worker + for i := 0; i < c.workerCount; i++ { + go c.worker(ctx, i, msgs) + } + + return nil +} + +// worker 工作协程 +func (c *Consumer) worker(ctx context.Context, workerID int, msgs <-chan amqp.Delivery) { + g.Log().Debugf(ctx, "Worker %d 已启动", workerID) + + for msg := range msgs { + // 处理消息 + err := c.handler(ctx, msg.Body) + + if err != nil { + g.Log().Errorf(ctx, "Worker %d 处理消息失败: %v", workerID, err) + + // 如果不是自动确认,需要手动 Nack + if !c.autoAck { + // requeue=false: 不重新入队,进入死信队列 + msg.Nack(false, false) + } + } else { + // 处理成功,手动确认 + if !c.autoAck { + msg.Ack(false) + } + + g.Log().Debugf(ctx, "Worker %d 处理消息成功", workerID) + } + } + + g.Log().Debugf(ctx, "Worker %d 已停止", workerID) +} + +// StartTypedConsumer 启动类型化消费者(自动反序列化) +func StartTypedConsumer[T any]( + ctx context.Context, + queue string, + handler func(ctx context.Context, msg *T) error, + opts ...ConsumerOption, +) error { + // 包装处理函数 + wrappedHandler := func(ctx context.Context, body []byte) error { + var msg T + if err := json.Unmarshal(body, &msg); err != nil { + return fmt.Errorf("反序列化消息失败: %v", err) + } + + return handler(ctx, &msg) + } + + consumer := NewConsumer(queue, wrappedHandler, opts...) + return consumer.Start(ctx) +} diff --git a/rabbitmq/publisher.go b/rabbitmq/publisher.go new file mode 100644 index 0000000..096744b --- /dev/null +++ b/rabbitmq/publisher.go @@ -0,0 +1,147 @@ +package rabbitmq + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/gogf/gf/v2/frame/g" + amqp "github.com/rabbitmq/amqp091-go" +) + +// Publisher 消息发布器 +type Publisher struct { + exchange string + routingKey string +} + +// NewPublisher 创建发布器 +func NewPublisher(exchange, routingKey string) *Publisher { + return &Publisher{ + exchange: exchange, + routingKey: routingKey, + } +} + +// Publish 发布消息 +func (p *Publisher) Publish(ctx context.Context, message interface{}) error { + ch, err := GetChannel() + if err != nil { + return err + } + + // 序列化消息 + body, err := json.Marshal(message) + if err != nil { + return fmt.Errorf("消息序列化失败: %v", err) + } + + // 发布消息 + err = ch.PublishWithContext( + ctx, + p.exchange, // exchange + p.routingKey, // routing key + false, // mandatory + false, // immediate + amqp.Publishing{ + DeliveryMode: amqp.Persistent, // 持久化 + ContentType: "application/json", + Body: body, + }, + ) + + if err != nil { + g.Log().Errorf(ctx, "发布消息失败: exchange=%s, routingKey=%s, err=%v", + p.exchange, p.routingKey, err) + return err + } + + g.Log().Debugf(ctx, "消息发布成功: exchange=%s, routingKey=%s", + p.exchange, p.routingKey) + + return nil +} + +// PublishDelayed 发布延时消息 +// delaySeconds: 延时秒数 +func (p *Publisher) PublishDelayed(ctx context.Context, message interface{}, delaySeconds int) error { + ch, err := GetChannel() + if err != nil { + return err + } + + // 序列化消息 + body, err := json.Marshal(message) + if err != nil { + return fmt.Errorf("消息序列化失败: %v", err) + } + + // 发布延时消息(需要 rabbitmq_delayed_message_exchange 插件) + err = ch.PublishWithContext( + ctx, + p.exchange, // exchange(必须是 x-delayed-message 类型) + p.routingKey, // routing key + false, // mandatory + false, // immediate + amqp.Publishing{ + DeliveryMode: amqp.Persistent, + ContentType: "application/json", + Body: body, + Headers: amqp.Table{ + "x-delay": delaySeconds * 1000, // 延时(毫秒) + }, + }, + ) + + if err != nil { + g.Log().Errorf(ctx, "发布延时消息失败: exchange=%s, routingKey=%s, delay=%ds, err=%v", + p.exchange, p.routingKey, delaySeconds, err) + return err + } + + g.Log().Debugf(ctx, "延时消息发布成功: exchange=%s, routingKey=%s, delay=%ds", + p.exchange, p.routingKey, delaySeconds) + + return nil +} + +// PublishBatch 批量发布消息 +func (p *Publisher) PublishBatch(ctx context.Context, messages []interface{}) error { + if len(messages) == 0 { + return nil + } + + ch, err := GetChannel() + if err != nil { + return err + } + + for i, message := range messages { + body, err := json.Marshal(message) + if err != nil { + g.Log().Errorf(ctx, "消息 %d 序列化失败: %v", i, err) + continue + } + + err = ch.PublishWithContext( + ctx, + p.exchange, + p.routingKey, + false, + false, + amqp.Publishing{ + DeliveryMode: amqp.Persistent, + ContentType: "application/json", + Body: body, + }, + ) + + if err != nil { + g.Log().Errorf(ctx, "消息 %d 发布失败: %v", i, err) + continue + } + } + + g.Log().Infof(ctx, "批量发布完成: 共 %d 条消息", len(messages)) + return nil +} diff --git a/rabbitmq/setup.go b/rabbitmq/setup.go new file mode 100644 index 0000000..bc8bf59 --- /dev/null +++ b/rabbitmq/setup.go @@ -0,0 +1,231 @@ +package rabbitmq + +import ( + "context" + "fmt" + + "github.com/gogf/gf/v2/frame/g" + amqp "github.com/rabbitmq/amqp091-go" +) + +// QueueConfig 队列配置 +type QueueConfig struct { + Name string + Durable bool // 持久化 + AutoDelete bool // 自动删除 + Exclusive bool // 排他 + Args amqp.Table // 额外参数 +} + +// ExchangeConfig Exchange 配置 +type ExchangeConfig struct { + Name string + Type string // direct/topic/fanout/x-delayed-message + Durable bool + AutoDelete bool + Args amqp.Table +} + +// BindingConfig 绑定配置 +type BindingConfig struct { + Queue string + Exchange string + RoutingKey string + Args amqp.Table +} + +// DeclareQueue 声明队列 +func DeclareQueue(ctx context.Context, cfg *QueueConfig) error { + ch, err := GetChannel() + if err != nil { + return err + } + + _, err = ch.QueueDeclare( + cfg.Name, + cfg.Durable, + cfg.AutoDelete, + cfg.Exclusive, + false, // no-wait + cfg.Args, + ) + + if err != nil { + g.Log().Errorf(ctx, "声明队列失败: %s, err=%v", cfg.Name, err) + return err + } + + g.Log().Infof(ctx, "队列声明成功: %s", cfg.Name) + return nil +} + +// DeclareExchange 声明 Exchange +func DeclareExchange(ctx context.Context, cfg *ExchangeConfig) error { + ch, err := GetChannel() + if err != nil { + return err + } + + err = ch.ExchangeDeclare( + cfg.Name, + cfg.Type, + cfg.Durable, + cfg.AutoDelete, + false, // internal + false, // no-wait + cfg.Args, + ) + + if err != nil { + g.Log().Errorf(ctx, "声明 Exchange 失败: %s, err=%v", cfg.Name, err) + return err + } + + g.Log().Infof(ctx, "Exchange 声明成功: %s (type=%s)", cfg.Name, cfg.Type) + return nil +} + +// BindQueue 绑定队列到 Exchange +func BindQueue(ctx context.Context, cfg *BindingConfig) error { + ch, err := GetChannel() + if err != nil { + return err + } + + err = ch.QueueBind( + cfg.Queue, + cfg.RoutingKey, + cfg.Exchange, + false, // no-wait + cfg.Args, + ) + + if err != nil { + g.Log().Errorf(ctx, "绑定队列失败: queue=%s, exchange=%s, routingKey=%s, err=%v", + cfg.Queue, cfg.Exchange, cfg.RoutingKey, err) + return err + } + + g.Log().Infof(ctx, "队列绑定成功: queue=%s → exchange=%s (routingKey=%s)", + cfg.Queue, cfg.Exchange, cfg.RoutingKey) + return nil +} + +// SetupDelayExchange 设置延时 Exchange(需要 rabbitmq_delayed_message_exchange 插件) +func SetupDelayExchange(ctx context.Context, exchangeName string) error { + return DeclareExchange(ctx, &ExchangeConfig{ + Name: exchangeName, + Type: "x-delayed-message", + Durable: true, + Args: amqp.Table{ + "x-delayed-type": "direct", + }, + }) +} + +// SetupDeadLetterQueue 设置死信队列 +func SetupDeadLetterQueue(ctx context.Context, queueName, exchangeName string) error { + // 1. 声明死信 Exchange + err := DeclareExchange(ctx, &ExchangeConfig{ + Name: exchangeName, + Type: "direct", + Durable: true, + }) + if err != nil { + return err + } + + // 2. 声明死信队列 + err = DeclareQueue(ctx, &QueueConfig{ + Name: queueName, + Durable: true, + }) + if err != nil { + return err + } + + // 3. 绑定 + return BindQueue(ctx, &BindingConfig{ + Queue: queueName, + Exchange: exchangeName, + RoutingKey: queueName, + }) +} + +// SetupQueueWithDLX 创建带死信队列的普通队列 +func SetupQueueWithDLX(ctx context.Context, queueName, dlxExchange, dlxRoutingKey string) error { + return DeclareQueue(ctx, &QueueConfig{ + Name: queueName, + Durable: true, + Args: amqp.Table{ + "x-dead-letter-exchange": dlxExchange, + "x-dead-letter-routing-key": dlxRoutingKey, + }, + }) +} + +// SetupBasicTopology 设置基础拓扑(适用于小红书客服场景) +func SetupBasicTopology(ctx context.Context) error { + // 1. 声明普通 Exchange + err := DeclareExchange(ctx, &ExchangeConfig{ + Name: "ragflow_exchange", + Type: "direct", + Durable: true, + }) + if err != nil { + return err + } + + // 2. 声明延时 Exchange + err = SetupDelayExchange(ctx, "delay_exchange") + if err != nil { + return fmt.Errorf("延时 Exchange 声明失败(可能未安装插件): %v", err) + } + + // 3. 声明死信队列 + err = SetupDeadLetterQueue(ctx, "dead_letter_queue", "dlx_exchange") + if err != nil { + return err + } + + // 4. 声明业务队列 + queues := []struct { + name string + dlx bool // 是否需要死信队列 + }{ + {"ragflow_request_queue", true}, + {"follow_up_queue", true}, + {"archive_queue", true}, + } + + for _, q := range queues { + if q.dlx { + err = SetupQueueWithDLX(ctx, q.name, "dlx_exchange", "dead_letter_queue") + } else { + err = DeclareQueue(ctx, &QueueConfig{ + Name: q.name, + Durable: true, + }) + } + if err != nil { + return err + } + } + + // 5. 绑定队列 + bindings := []BindingConfig{ + {Queue: "ragflow_request_queue", Exchange: "ragflow_exchange", RoutingKey: "ragflow_request_queue"}, + {Queue: "follow_up_queue", Exchange: "delay_exchange", RoutingKey: "follow_up_queue"}, + {Queue: "archive_queue", Exchange: "delay_exchange", RoutingKey: "archive_queue"}, + } + + for _, b := range bindings { + err = BindQueue(ctx, &b) + if err != nil { + return err + } + } + + g.Log().Info(ctx, "RabbitMQ 拓扑结构设置完成") + return nil +} diff --git a/ragflow/client.go b/ragflow/client.go index c340f73..bdbcd39 100644 --- a/ragflow/client.go +++ b/ragflow/client.go @@ -9,27 +9,79 @@ import ( "net/url" "strings" + "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/net/gclient" + "github.com/gogf/gf/v2/os/gcfg" ) +var ( + // globalClient 全局 RAGFlow 客户端(单例,自动初始化) + globalClient *Client +) + +// init 包初始化时自动创建全局客户端 +func init() { + ctx := context.Background() + + // 读取配置 + baseURL, apiKey := loadConfig(ctx) + + // 如果配置不完整,跳过初始化 + if baseURL == "" || apiKey == "" { + g.Log().Warning(ctx, "⚠️ RAGFlow 配置未找到,请在 common/ragflow/config.yaml 中配置") + return + } + + // 初始化全局客户端 + httpClient := gclient.New() + httpClient.SetHeader("Authorization", fmt.Sprintf("Bearer %s", apiKey)) + httpClient.SetHeader("Content-Type", "application/json") + + globalClient = &Client{ + BaseURL: strings.TrimSuffix(baseURL, "/"), + APIKey: apiKey, + HTTPClient: httpClient, + } + + g.Log().Infof(ctx, "✅ RAGFlow 全局客户端初始化成功: baseURL=%s", baseURL) +} + +// loadConfig 从配置文件加载 RAGFlow 配置 +func loadConfig(ctx context.Context) (baseURL, apiKey string) { + // 创建配置实例 + cfg, err := gcfg.New() + if err != nil { + g.Log().Debugf(ctx, "创建配置实例失败: %v", err) + return "", "" + } + + // 设置配置文件 + adapter, ok := cfg.GetAdapter().(*gcfg.AdapterFile) + if !ok { + g.Log().Debug(ctx, "配置适配器类型不匹配") + return "", "" + } + + adapter.SetFileName("config.yaml") + + // 读取配置项 + baseURL = cfg.MustGet(ctx, "ragflow.base_url").String() + apiKey = cfg.MustGet(ctx, "ragflow.api_key").String() + + return baseURL, apiKey +} + +// GetGlobalClient 获取全局客户端 +// 使用示例:client := ragflow.GetGlobalClient() +func GetGlobalClient() *Client { + return globalClient +} + // Client RAGFlow API 客户端 type Client struct { BaseURL string APIKey string - HTTPClient *gclient.Client -} - -// NewClient 创建新的 RAGFlow 客户端 -func NewClient(baseURL, apiKey string) *Client { - client := gclient.New() - client.SetHeader("Authorization", fmt.Sprintf("Bearer %s", apiKey)) - client.SetHeader("Content-Type", "application/json") - - return &Client{ - BaseURL: strings.TrimSuffix(baseURL, "/"), - APIKey: apiKey, - HTTPClient: client, - } + HTTPClient *gclient.Client // HTTP 客户端 } // CommonResponse 通用响应结构 diff --git a/ragflow/config.yaml b/ragflow/config.yaml new file mode 100644 index 0000000..95c7e13 --- /dev/null +++ b/ragflow/config.yaml @@ -0,0 +1,10 @@ +# RAGFlow 配置文件 +# 用于全局客户端自动初始化 + +ragflow: + # RAGFlow 服务地址 + base_url: "http://localhost:9380" + + # RAGFlow API Key + # 获取方式:登录 RAGFlow 管理界面 -> 设置 -> API Keys + api_key: "ragflow-your-api-key-here"