diff --git a/message/message.go b/message/message.go index e2c72c1..321fc5d 100644 --- a/message/message.go +++ b/message/message.go @@ -7,11 +7,6 @@ import ( "github.com/gogf/gf/v2/errors/gerror" ) -// GetRedisClient 获取 Redis 客户端(供外部使用) -func GetRedisClient() *gredis.Redis { - return getRedisClient() -} - func GetRedisClientTest(name string) *gredis.Redis { return getRedisClientTest(name) } diff --git a/message/msg_queue.go b/message/msg_queue.go new file mode 100644 index 0000000..be1f2da --- /dev/null +++ b/message/msg_queue.go @@ -0,0 +1,152 @@ +package message + +import ( + "context" + "fmt" +) + +type RedisConfig struct { + // Stream 名称 + Stream string + + // 消费者组名称 + Group string + + // 消费者名称 + Consumer string + + // 每次消费数量 + Count int64 + + // 是否自动 ACK + AutoAck bool + + // 处理函数 + HandleFunc func(ctx context.Context, message map[string]interface{}) error +} + +// RabbitMQConfig RabbitMQ 队列配置 +type RabbitMQConfig struct { + Mode string + Exchange string + Topic string + DelayMessage bool + + // 队列名称(必需) + Name string + + // 实际队列名(用于绑定) + Queue string + + // 是否持久化 + Durable bool + + // QoS 预取数量(每次推送的消息数量,默认10) + PrefetchCount int + + // 最大重试次数(默认3) + MaxRetry int + + // 是否自动 ACK + AutoAck bool + + // 处理函数 + HandleFunc func(ctx context.Context, message map[string]interface{}) error +} + +// NATSConfig NATS 队列配置 +type NATSConfig struct { + DelayMessage bool + // Stream 名称 + Stream string + + // 消费者名称 + Consumer string + + // 是否持久化 + Durable bool + + // 副本数 + Replicas int + // QoS 预取数量(每次推送的消息数量,默认10) + PrefetchCount int + + // 是否自动 ACK + AutoAck bool + + // 处理函数 + HandleFunc func(ctx context.Context, message map[string]interface{}) error +} + +// messageBroker 消息代理接口 +type messageBroker interface { + // StreamGroup 创建消费组(支持单个配置或批量配置) + streamGroup(ctx context.Context, configs ...interface{}) error + + // Publish 发布消息(支持单个配置或批量配置) + publish(ctx context.Context, config interface{}, data interface{}) error + + // PublishDelayed 发布延迟消息(支持单个配置或批量配置) + publishDelayed(ctx context.Context, config interface{}, data interface{}, delay int) error + + // Subscribe 订阅消息(支持单个配置或批量配置) + subscribe(ctx context.Context, configs ...interface{}) error +} + +type messageClientType string + +const ( + ClientTypeRedis messageClientType = "redis" + ClientTypeRabbitMQ messageClientType = "rabbitmq" + ClientTypeNATS messageClientType = "nats" +) + +// newMessageBroker 创建消息代理实例 +func newMessageBroker(ctx context.Context, clientType messageClientType) (messageBroker, error) { + switch clientType { + case ClientTypeRedis: + return &redisMessageClient{clientType: clientType}, nil + case ClientTypeRabbitMQ: + return &rabbitMQMessageClient{clientType: clientType}, nil + case ClientTypeNATS: + return &natsMessageClient{clientType: clientType}, nil + default: + return nil, fmt.Errorf("unknown client type: %s", clientType) + } +} + +// StreamGroup 直接创建消费组 +func StreamGroup(ctx context.Context, clientType messageClientType, configs ...interface{}) error { + broker, err := newMessageBroker(ctx, clientType) + if err != nil { + return err + } + return broker.streamGroup(ctx, configs...) +} + +// Publish 直接发布消息 +func Publish(ctx context.Context, clientType messageClientType, config interface{}, data interface{}) error { + broker, err := newMessageBroker(ctx, clientType) + if err != nil { + return err + } + return broker.publish(ctx, config, data) +} + +// PublishDelayed 直接发布延迟消息 +func PublishDelayed(ctx context.Context, clientType messageClientType, config interface{}, data interface{}, delay int) error { + broker, err := newMessageBroker(ctx, clientType) + if err != nil { + return err + } + return broker.publishDelayed(ctx, config, data, delay) +} + +// Subscribe 直接订阅消息 +func Subscribe(ctx context.Context, clientType messageClientType, configs ...interface{}) error { + broker, err := newMessageBroker(ctx, clientType) + if err != nil { + return err + } + return broker.subscribe(ctx, configs...) +} diff --git a/nats/connection.go b/message/nats_client.go similarity index 79% rename from nats/connection.go rename to message/nats_client.go index eb8078a..470bef7 100644 --- a/nats/connection.go +++ b/message/nats_client.go @@ -1,4 +1,4 @@ -package nats +package message import ( "context" @@ -16,7 +16,7 @@ var ( nc *nats.Conn js jetstream.JetStream inited bool - mu sync.RWMutex + natsMu sync.RWMutex natsURL string healthCtx context.Context healthCancel context.CancelFunc @@ -24,15 +24,15 @@ var ( reconnectChan chan struct{} // 连接状态变化监听器 - connStateListeners []ConnStateListener + connStateListeners []connStateListener connListenersMu sync.RWMutex // 监控指标 - metrics Metrics + metrics metricsCounter ) // Metrics 监控指标 -type Metrics struct { +type metricsCounter struct { PublishCount atomic.Int64 PublishError atomic.Int64 SubscribeCount atomic.Int64 @@ -43,33 +43,33 @@ type Metrics struct { } // ConnState 连接状态 -type ConnState int +type connState int const ( - ConnStateDisconnected ConnState = iota - ConnStateConnecting - ConnStateConnected - ConnStateReconnecting - ConnStateClosed + connStateDisconnected connState = iota + connStateConnecting + connStateConnected + connStateReconnecting + connStateClosed ) // ConnStateListener 连接状态监听器 -type ConnStateListener func(state ConnState, err error) +type connStateListener func(state connState, err error) // GetMetrics 获取监控指标 -func GetMetrics() Metrics { +func getMetrics() metricsCounter { return metrics } -// RegisterConnStateListener 注册连接状态监听器 -func RegisterConnStateListener(listener ConnStateListener) { +// registerConnStateListener 注册连接状态监听器 +func registerConnStateListener(listener connStateListener) { connListenersMu.Lock() defer connListenersMu.Unlock() connStateListeners = append(connStateListeners, listener) } -// UnregisterConnStateListener 取消注册连接状态监听器 -func UnregisterConnStateListener(listener ConnStateListener) { +// unregisterConnStateListener 取消注册连接状态监听器 +func unregisterConnStateListener(listener connStateListener) { connListenersMu.Lock() defer connListenersMu.Unlock() for i, l := range connStateListeners { @@ -81,9 +81,9 @@ func UnregisterConnStateListener(listener ConnStateListener) { } // notifyConnState 通知所有监听器连接状态变化 -func notifyConnState(state ConnState, err error) { +func notifyConnState(state connState, err error) { connListenersMu.RLock() - listeners := make([]ConnStateListener, len(connStateListeners)) + listeners := make([]connStateListener, len(connStateListeners)) copy(listeners, connStateListeners) connListenersMu.RUnlock() @@ -119,17 +119,17 @@ func init() { // initConnection 初始化连接 func initConnection() { ctx := context.Background() - notifyConnState(ConnStateConnecting, nil) + notifyConnState(connStateConnecting, nil) if err := connect(ctx); err != nil { g.Log().Errorf(ctx, "NATS 初始连接失败: %v", err) - notifyConnState(ConnStateDisconnected, err) + notifyConnState(connStateDisconnected, err) } } // connect 建立 NATS 连接 func connect(ctx context.Context) error { - mu.Lock() - defer mu.Unlock() + natsMu.Lock() + defer natsMu.Unlock() if nc != nil && !nc.IsClosed() { nc.Close() @@ -152,7 +152,7 @@ func connect(ctx context.Context) error { } // 通知重连成功 - notifyConnState(ConnStateConnected, nil) + notifyConnState(connStateConnected, nil) // 使用非阻塞发送避免阻塞 select { @@ -164,12 +164,12 @@ func connect(ctx context.Context) error { nats.DisconnectErrHandler(func(nc *nats.Conn, err error) { g.Log().Warningf(ctx, "⚠️ NATS 连接断开: %v, 准备重连...", err) connected = false - notifyConnState(ConnStateReconnecting, err) + notifyConnState(connStateReconnecting, err) }), nats.ClosedHandler(func(nc *nats.Conn) { g.Log().Infof(ctx, "NATS 连接已关闭: %s", nc.ConnectedUrl()) connected = false - notifyConnState(ConnStateClosed, nil) + notifyConnState(connStateClosed, nil) }), nats.ErrorHandler(func(nc *nats.Conn, sub *nats.Subscription, err error) { g.Log().Errorf(ctx, "NATS 错误: %v", err) @@ -186,7 +186,7 @@ func connect(ctx context.Context) error { if nc.Status() != nats.CONNECTED { select { case <-time.After(5 * time.Second): - notifyConnState(ConnStateDisconnected, fmt.Errorf("连接超时")) + notifyConnState(connStateDisconnected, fmt.Errorf("连接超时")) return fmt.Errorf("NATS 连接超时") case <-nc.StatusChanged(nats.CONNECTED): } @@ -201,7 +201,7 @@ func connect(ctx context.Context) error { connected = true inited = true g.Log().Infof(ctx, "✅ NATS 连接成功: %s", nc.ConnectedUrl()) - notifyConnState(ConnStateConnected, nil) + notifyConnState(connStateConnected, nil) return nil } @@ -215,10 +215,10 @@ func healthCheck() { case <-healthCtx.Done(): return case <-ticker.C: - mu.RLock() + natsMu.RLock() currentConnected := connected currentConn := nc - mu.RUnlock() + natsMu.RUnlock() if !currentConnected || currentConn == nil || currentConn.IsClosed() { // 仅记录日志,不尝试重连(NATS 已有自动重连机制) @@ -233,38 +233,33 @@ func healthCheck() { // checkConnected 检查连接状态 func checkConnected() bool { - mu.RLock() - defer mu.RUnlock() + natsMu.RLock() + defer natsMu.RUnlock() return connected && nc != nil && !nc.IsClosed() } -// IsConnected 检查 NATS 是否已连接 -func IsConnected() bool { - return checkConnected() -} - -// GetConnState 获取当前连接状态 -func GetConnState() ConnState { - mu.RLock() - defer mu.RUnlock() +// getConnState 获取当前连接状态 +func getConnState() connState { + natsMu.RLock() + defer natsMu.RUnlock() if nc == nil { - return ConnStateDisconnected + return connStateDisconnected } if nc.IsClosed() { - return ConnStateClosed + return connStateClosed } if connected { - return ConnStateConnected + return connStateConnected } - return ConnStateDisconnected + return connStateDisconnected } -// Shutdown 优雅关闭:自动注销所有已注册的服务并关闭 NATS 连接 -func Shutdown() error { +// shutdown 优雅关闭:自动注销所有已注册的服务并关闭 NATS 连接 +func shutdown() error { ctx := context.Background() g.Log().Info(ctx, "开始优雅关闭 NATS RPC 服务...") @@ -299,8 +294,8 @@ func Shutdown() error { g.Log().Infof(ctx, "已注销 %d 个单实例服务和 %d 个队列服务", singleServiceCount, queueServiceCount) - mu.Lock() - defer mu.Unlock() + natsMu.Lock() + defer natsMu.Unlock() // 停止健康检查协程 if healthCancel != nil { diff --git a/nats/nats_rpc.go b/message/nats_rpc.go similarity index 89% rename from nats/nats_rpc.go rename to message/nats_rpc.go index 8cbb8ba..8644734 100644 --- a/nats/nats_rpc.go +++ b/message/nats_rpc.go @@ -1,4 +1,4 @@ -package nats +package message import ( "context" @@ -7,6 +7,7 @@ import ( "fmt" "github.com/gogf/gf/v2/frame/g" "github.com/nats-io/nats.go" + "go.opentelemetry.io/otel/trace" "reflect" "sync" ) @@ -509,12 +510,12 @@ func AutoRegisterServices(ctx context.Context, serviceInstances map[string]inter return fmt.Errorf("未能注册任何服务") } // 设置取消监听器(监听基于 TraceID 的取消请求) - //if _, err := setupCancelListener(ctx); err != nil { - // g.Log().Errorf(ctx, "设置取消监听器失败: %v", err) - //} else { - // g.Log().Infof(ctx, "✅ 取消监听器已自动设置") - //} - //g.Log().Infof(ctx, "✅ 共自动注册了 %d 个服务", totalRegistered) + if _, err := setupCancelListener(ctx); err != nil { + g.Log().Errorf(ctx, "设置取消监听器失败: %v", err) + } else { + g.Log().Infof(ctx, "✅ 取消监听器已自动设置") + } + g.Log().Infof(ctx, "✅ 共自动注册了 %d 个服务", totalRegistered) return nil } @@ -671,3 +672,81 @@ func registerService(service interface{}, serviceNamePrefix string, options ...R g.Log().Infof(context.Background(), "✅ Service %v 共注册了 %d 个 RPC 方法", serviceNamePrefix, registeredCount) return nil } + +// ============ 上下文元数据工具函数 ============ +// 以下函数用于在 context 和 NATS 消息头之间互转元数据 + +// 定义常见的上下文元数据 key +const ( + TraceIDKey = "trace_id" + TokenKey = "token" +) + +func getTraceID(ctx context.Context) (traceID string, err error) { + // 提取 traceId:首先尝试从 OpenTelemetry Span 中提取,从 context 中提取 TraceID + span := trace.SpanFromContext(ctx) + if span != nil && span.SpanContext().HasTraceID() { + traceID = span.SpanContext().TraceID().String() + } else if tid := ctx.Value(TraceIDKey); tid != nil { + traceID = fmt.Sprintf("%v", tid) + } + if traceID == "" { + return traceID, fmt.Errorf("context 中没有 TraceID") + } + return +} + +// contextToHeaders 将 context 中的元数据转换为 NATS 消息头 +// 支持提取 user_id、tenant_id、trace_id、token 等常见字段 +func contextToHeaders(ctx context.Context) (nats.Header, error) { + headers := make(nats.Header) + + // 提取 traceId:首先尝试从 OpenTelemetry Span 中提取 + if traceID, err := getTraceID(ctx); err != nil { + return headers, err + } else { + headers.Set(TraceIDKey, traceID) + } + + // 提取 token(优先级:context value > HTTP Authorization header) + token := "" + if t := ctx.Value(TokenKey); t != nil { + token = fmt.Sprintf("%v", t) + } else if r := g.RequestFromCtx(ctx); r != nil { + // 从 HTTP 请求的 Authorization header 中提取 token + auth := r.GetHeader("Authorization") + if auth != "" { + // 移除 "Bearer " 前缀 + if len(auth) > 7 && auth[:7] == "Bearer " { + token = auth[7:] + } else { + token = auth + } + } + } + if token != "" { + headers.Set(TokenKey, token) + } + + return headers, nil +} + +// headersToContext 从 NATS 消息头重建 context +// 支持还原 user_id、tenant_id、trace_id、token 等字段 +func headersToContext(ctx context.Context, headers nats.Header) context.Context { + if headers == nil { + return ctx + } + + // 恢复 trace_id + if traceID := headers.Get(TraceIDKey); traceID != "" { + ctx = context.WithValue(ctx, TraceIDKey, traceID) + } + + // 恢复 token + if token := headers.Get(TokenKey); token != "" { + ctx = context.WithValue(ctx, TokenKey, token) + } + + return ctx +} diff --git a/message/rabbit.go b/message/rabbit.go index bd6e010..1ef01fb 100644 --- a/message/rabbit.go +++ b/message/rabbit.go @@ -22,7 +22,7 @@ var ( ) // Config RabbitMQ 配置 -type RabbitMQConfig struct { +type RabbitMQConfig1 struct { Host string Port int Username string @@ -31,8 +31,8 @@ type RabbitMQConfig struct { } // rabbitMQConfig 默认配置 -func getRabbitMQConfig() *RabbitMQConfig { - return &RabbitMQConfig{ +func getRabbitMQConfig() *RabbitMQConfig1 { + return &RabbitMQConfig1{ Host: g.Cfg().MustGet(context.Background(), "rabbitmq.host").String(), Port: g.Cfg().MustGet(context.Background(), "rabbitmq.port").Int(), Username: g.Cfg().MustGet(context.Background(), "rabbitmq.username").String(), diff --git a/rabbitmq/client.go b/message/rabbitmq_client.go similarity index 93% rename from rabbitmq/client.go rename to message/rabbitmq_client.go index 68eda88..fc225f7 100644 --- a/rabbitmq/client.go +++ b/message/rabbitmq_client.go @@ -1,4 +1,4 @@ -package rabbitmq +package message import ( "context" @@ -14,8 +14,8 @@ import ( var ( conn *amqp.Connection channel *amqp.Channel - once sync.Once - mu sync.RWMutex + rabbitmqOnce sync.Once + rabbitmqMu sync.RWMutex closeWatcher chan struct{} // 用于停止监听 goroutine watcherStarted bool // 防止重复启动监听 ) @@ -32,7 +32,7 @@ type Config struct { // Init 初始化 RabbitMQ 连接 func Init(ctx context.Context, cfg *Config) error { var err error - once.Do(func() { + rabbitmqOnce.Do(func() { // 构建连接字符串 url := "amqp://" + cfg.Username + ":" + cfg.Password + "@" + cfg.Host + ":" + gconv.String(cfg.Port) + "/" + cfg.VHost @@ -80,8 +80,8 @@ func InitFromConfig(ctx context.Context) error { // GetChannel 获取 Channel func GetChannel() (*amqp.Channel, error) { - mu.RLock() - defer mu.RUnlock() + rabbitmqMu.RLock() + defer rabbitmqMu.RUnlock() if channel == nil || channel.IsClosed() { return nil, gerror.New("RabbitMQ Channel 未初始化或已关闭") @@ -92,8 +92,8 @@ func GetChannel() (*amqp.Channel, error) { // GetConnection 获取连接 func GetConnection() (*amqp.Connection, error) { - mu.RLock() - defer mu.RUnlock() + rabbitmqMu.RLock() + defer rabbitmqMu.RUnlock() if conn == nil || conn.IsClosed() { return nil, gerror.New("RabbitMQ 连接未初始化或已关闭") @@ -113,9 +113,9 @@ func handleConnectionClose(ctx context.Context) { default: } - mu.RLock() + rabbitmqMu.RLock() currentConn := conn - mu.RUnlock() + rabbitmqMu.RUnlock() if currentConn == nil { return @@ -141,8 +141,8 @@ func handleConnectionClose(ctx context.Context) { // reconnect 重新连接 func reconnect(ctx context.Context) { - mu.Lock() - defer mu.Unlock() + rabbitmqMu.Lock() + defer rabbitmqMu.Unlock() for i := 0; i < 10; i++ { time.Sleep(time.Duration(i+1) * time.Second) @@ -180,8 +180,8 @@ func reconnect(ctx context.Context) { // Close 关闭连接 func Close(ctx context.Context) (err error) { - mu.Lock() - defer mu.Unlock() + rabbitmqMu.Lock() + defer rabbitmqMu.Unlock() // 停止监听 goroutine if closeWatcher != nil { diff --git a/message/redis.go b/message/redis.go index 7d5c76b..67e88f1 100644 --- a/message/redis.go +++ b/message/redis.go @@ -3,6 +3,7 @@ package message import ( "context" "errors" + "fmt" "strings" "time" @@ -18,11 +19,6 @@ type StreamMessage struct { Values map[string]interface{} // 消息内容 } -// getClient 获取 Redis 客户端 -func getRedisClient() *gredis.Redis { - return g.Redis() -} - // getClient 获取 Redis 客户端 func getRedisClientTest(name string) *gredis.Redis { return g.Redis(name) @@ -47,46 +43,66 @@ func getRedisClientByDB(db int) *gredis.Redis { // lock 分布式锁 func lock(ctx context.Context, key string, expireSeconds int64, fn func(ctx context.Context) error) (success bool, err error) { - limit := 3 -LOOP: - if limit < 0 { - return false, errors.New("锁重试次数耗尽") + ds, err := GetManager().GetDefaultDataSource() + if err != nil { + return false, fmt.Errorf("获取默认数据源失败: %w", err) } - limit-- - if val, err := getRedisClient().Set(ctx, key, true, gredis.SetOption{ - TTLOption: gredis.TTLOption{ - EX: &expireSeconds, - }, - NX: true, - }); err != nil { - return false, err - } else { - if val.Bool() { - defer func(RedisClient *gredis.Redis, ctx context.Context, key string) { - if _, err = RedisClient.Del(ctx, key); err != nil { - glog.Errorf(ctx, "RedisClient.Del error: %v", err) - } - }(getRedisClient(), ctx, key) - if err = fn(ctx); err != nil { - return false, err - } - return true, nil + + maxRetries := 3 + for i := 0; i < maxRetries; i++ { + if val, err := ds.Redis().Set(ctx, key, true, gredis.SetOption{ + TTLOption: gredis.TTLOption{ + EX: &expireSeconds, + }, + NX: true, + }); err != nil { + return false, err } else { - time.Sleep(time.Second) - goto LOOP + if val.Bool() { + defer func(redisClient *gredis.Redis, ctx context.Context, key string) { + if _, err = redisClient.Del(ctx, key); err != nil { + glog.Errorf(ctx, "RedisClient.Del error: %v", err) + } + }(ds.Redis(), ctx, key) + if err = fn(ctx); err != nil { + return false, err + } + return true, nil + } else { + // 检查上下文是否已取消 + if ctx.Err() != nil { + return false, ctx.Err() + } + // 非最后一次重试时才等待 + if i < maxRetries-1 { + time.Sleep(time.Second) + } + } } } + return false, errors.New("锁重试次数耗尽") } // publishToRedis 将消息添加到 Redis Stream func publishToRedis(ctx context.Context, streamKey string, msg interface{}) (messageID string, err error) { + ds, err := GetManager().GetDefaultDataSource() + if err != nil { + return "", fmt.Errorf("获取默认数据源失败: %w", err) + } + + if !ds.IsConnected() { + if err := ds.Reconnect(ctx); err != nil { + return "", fmt.Errorf("redis重连失败: %w", err) + } + } + values := gconv.Map(msg) args := make([]interface{}, 0, len(values)*2+2) args = append(args, streamKey, "*") for key, val := range values { args = append(args, key, val) } - result, err := getRedisClient().Do(ctx, "XADD", args...) + result, err := ds.Redis().Do(ctx, "XADD", args...) if err != nil { return } @@ -96,7 +112,18 @@ func publishToRedis(ctx context.Context, streamKey string, msg interface{}) (mes // initStreamGroup 初始化 Stream 和消费者组 func initStreamGroup(ctx context.Context, streamKey, groupName string) error { - _, err := getRedisClient().Do(ctx, "XGROUP", "CREATE", streamKey, groupName, "0", "MKSTREAM") + ds, err := GetManager().GetDefaultDataSource() + if err != nil { + return fmt.Errorf("获取默认数据源失败: %w", err) + } + + if !ds.IsConnected() { + if err := ds.Reconnect(ctx); err != nil { + return fmt.Errorf("redis重连失败: %w", err) + } + } + + _, err = ds.Redis().Do(ctx, "XGROUP", "CREATE", streamKey, groupName, "0", "MKSTREAM") if err != nil { // 如果组已存在,忽略错误 errStr := err.Error() @@ -113,6 +140,11 @@ func initStreamGroup(ctx context.Context, streamKey, groupName string) error { // readFromStream 从 Stream 读取消息 func readFromStream(ctx context.Context, msg QueueMessage) error { + ds, err := GetManager().GetDefaultDataSource() + if err != nil { + return fmt.Errorf("获取默认数据源失败: %w", err) + } + // 初始化 Stream 和消费者组 if err := initStreamGroup(ctx, msg.StreamKey, msg.GroupName); err != nil { return err @@ -120,7 +152,7 @@ func readFromStream(ctx context.Context, msg QueueMessage) error { go func() { RECONNECT: for { - result, err := getRedisClient().Do(ctx, "XREADGROUP", "GROUP", msg.GroupName, msg.ConsumerName, "COUNT", msg.BatchSize, "BLOCK", 0, "STREAMS", msg.StreamKey, ">") + result, err := ds.Redis().Do(ctx, "XREADGROUP", "GROUP", msg.GroupName, msg.ConsumerName, "COUNT", msg.BatchSize, "BLOCK", 0, "STREAMS", msg.StreamKey, ">") if err != nil { //select { //case <-ctx.Done(): @@ -222,11 +254,22 @@ func readFromStream(ctx context.Context, msg QueueMessage) error { // ackMessage 确认消息已处理 func ackMessage(ctx context.Context, streamKey, groupName string, messageIDs ...string) error { + ds, err := GetManager().GetDefaultDataSource() + if err != nil { + return fmt.Errorf("获取默认数据源失败: %w", err) + } + + if !ds.IsConnected() { + if err := ds.Reconnect(ctx); err != nil { + return fmt.Errorf("redis重连失败: %w", err) + } + } + args := make([]interface{}, 0, len(messageIDs)+2) args = append(args, streamKey, groupName) for _, id := range messageIDs { args = append(args, id) } - _, err := getRedisClient().Do(ctx, "XACK", args...) + _, err = ds.Redis().Do(ctx, "XACK", args...) return err } diff --git a/message/redis_client.go b/message/redis_client.go new file mode 100644 index 0000000..2db78cb --- /dev/null +++ b/message/redis_client.go @@ -0,0 +1,468 @@ +// ============================================================================= +// Redis 数据源连接管理 +// 使用 GoFrame 框架自带的 Redis 客户端,负责数据源的连接、重连、健康检查和优雅关闭 +// ============================================================================= + +package message + +import ( + "context" + "fmt" + "os" + "os/signal" + "sync" + "sync/atomic" + "syscall" + "time" + + "github.com/gogf/gf/v2/database/gredis" + "github.com/gogf/gf/v2/frame/g" + "github.com/gogf/gf/v2/os/glog" + "github.com/gogf/gf/v2/util/gconv" +) + +// ============================================================================= +// 数据源配置结构 +// ============================================================================= + +type RedisDataSourceConfig struct { + Name string `json:"name"` // 数据源名称 + Address string `json:"address"` // Redis 地址,如: 127.0.0.1:6379 + Db int `json:"db"` // 数据库编号 + Pass string `json:"pass"` // 密码 + Timeout time.Duration `json:"timeout"` // 连接超时 + MaxIdle int `json:"maxIdle"` // 最大空闲连接数 + MaxOpen int `json:"maxOpen"` // 最大活跃连接数 +} + +// ============================================================================= +// 单个数据源接口 +// ============================================================================= + +type DataSource interface { + Name() string + Redis() *gredis.Redis + IsConnected() bool + Connect(ctx context.Context) error + Reconnect(ctx context.Context) error + Close(ctx context.Context) error +} + +// ============================================================================= +// 数据源实现 +// ============================================================================= + +type BaseDataSource struct { + config *RedisDataSourceConfig + redis *gredis.Redis + isConnected bool + mu sync.RWMutex + lastError error + lastErrorTime time.Time + metrics RedisMetrics +} + +func NewBaseDataSource(config *RedisDataSourceConfig) *BaseDataSource { + return &BaseDataSource{ + config: config, + isConnected: false, + } +} + +func (d *BaseDataSource) Name() string { + return d.config.Name +} + +func (d *BaseDataSource) Redis() *gredis.Redis { + d.mu.RLock() + defer d.mu.RUnlock() + return d.redis +} + +func (d *BaseDataSource) IsConnected() bool { + d.mu.RLock() + defer d.mu.RUnlock() + return d.isConnected && d.redis != nil +} + +func (d *BaseDataSource) Connect(ctx context.Context) error { + d.mu.Lock() + defer d.mu.Unlock() + + // 设置默认值 + config := d.config + if config.Timeout == 0 { + config.Timeout = 10 * time.Second + } + if config.MaxIdle == 0 { + config.MaxIdle = 10 + } + if config.MaxOpen == 0 { + config.MaxOpen = 100 + } + + // 构建 GoFrame Redis 配置 + redisConfig := &gredis.Config{ + Address: config.Address, + Db: config.Db, + Pass: config.Pass, + } + + // 使用 GoFrame 的 Redis 连接 + redisObj, err := gredis.New(redisConfig) + if err != nil { + d.isConnected = false + d.lastError = err + d.lastErrorTime = time.Now() + d.metrics.PingError.Add(1) + return fmt.Errorf("datasource [%s] connection failed: %w", d.config.Name, err) + } + + d.redis = redisObj + + // 测试连接 + if err := d.Ping(ctx); err != nil { + d.isConnected = false + d.lastError = err + d.lastErrorTime = time.Now() + return fmt.Errorf("datasource [%s] ping failed: %w", d.config.Name, err) + } + + d.isConnected = true + d.lastError = nil + glog.Infof(ctx, "✅ datasource [%s] connected successfully", d.config.Name) + return nil +} + +func (d *BaseDataSource) Ping(ctx context.Context) error { + defer func() { + if r := recover(); r != nil { + d.metrics.PingError.Add(1) + glog.Errorf(ctx, "❌ datasource [%s] ping panic: %v", d.config.Name, r) + } + }() + + if d.redis == nil { + d.metrics.PingError.Add(1) + return fmt.Errorf("redis client is nil") + } + + _, err := d.redis.Do(ctx, "PING") + if err != nil { + d.metrics.PingError.Add(1) + return err + } + + d.metrics.PingCount.Add(1) + return nil +} + +func (d *BaseDataSource) Reconnect(ctx context.Context) error { + glog.Infof(ctx, "🔄 reconnecting datasource [%s]", d.config.Name) + return d.Connect(ctx) +} + +func (d *BaseDataSource) Close(ctx context.Context) error { + d.mu.Lock() + defer d.mu.Unlock() + + if d.redis != nil { + if err := d.redis.Close(ctx); err != nil { + return fmt.Errorf("datasource [%s] close failed: %w", d.config.Name, err) + } + } + + d.isConnected = false + d.redis = nil + glog.Infof(ctx, "datasource [%s] closed", d.config.Name) + return nil +} + +func (d *BaseDataSource) GetMetrics() RedisMetrics { + return d.metrics +} + +// ============================================================================= +// 监控指标 +// ============================================================================= + +type RedisMetrics struct { + PingCount atomic.Int64 + PingError atomic.Int64 + CommandCount atomic.Int64 + CommandError atomic.Int64 +} + +// GetPingMetrics 获取 Ping 相关指标 +func (m *RedisMetrics) GetPingMetrics() (int64, int64) { + return m.PingCount.Load(), m.PingError.Load() +} + +// GetCommandMetrics 获取命令相关指标 +func (m *RedisMetrics) GetCommandMetrics() (int64, int64) { + return m.CommandCount.Load(), m.CommandError.Load() +} + +// ============================================================================= +// 多数据源管理器 +// ============================================================================= + +type DataSourceManager struct { + sources map[string]DataSource + mu sync.RWMutex + ctx context.Context + cancel context.CancelFunc + started bool + maxRetries int + metrics RedisMetrics +} + +var ( + manager *DataSourceManager + once sync.Once +) + +// GetManager 获取全局管理器 +func GetManager() *DataSourceManager { + once.Do(func() { + ctx, cancel := context.WithCancel(context.Background()) + manager = &DataSourceManager{ + sources: make(map[string]DataSource), + ctx: ctx, + cancel: cancel, + started: false, + maxRetries: 3, + } + }) + return manager +} + +// RegisterDataSource 注册数据源 +func (m *DataSourceManager) RegisterDataSource(config *RedisDataSourceConfig) error { + m.mu.Lock() + defer m.mu.Unlock() + + if _, exists := m.sources[config.Name]; exists { + return fmt.Errorf("datasource [%s] already exists", config.Name) + } + + source := NewBaseDataSource(config) + m.sources[config.Name] = source + return nil +} + +// GetDataSource 获取数据源 +func (m *DataSourceManager) GetDataSource(name string) (DataSource, error) { + m.mu.RLock() + defer m.mu.RUnlock() + + source, exists := m.sources[name] + if !exists { + return nil, fmt.Errorf("datasource [%s] not found", name) + } + return source, nil +} + +// GetAllDataSourceNames 获取所有数据源名称 +func (m *DataSourceManager) GetAllDataSourceNames() []string { + m.mu.RLock() + defer m.mu.RUnlock() + + names := make([]string, 0, len(m.sources)) + for name := range m.sources { + names = append(names, name) + } + return names +} + +// GetDefaultDataSource 获取默认数据源(第一个注册的数据源) +func (m *DataSourceManager) GetDefaultDataSource() (DataSource, error) { + m.mu.RLock() + defer m.mu.RUnlock() + + for _, source := range m.sources { + return source, nil + } + return nil, fmt.Errorf("no datasource available") +} + +// GetMetrics 获取全局监控指标 +func (m *DataSourceManager) GetMetrics() RedisMetrics { + return m.metrics +} + +// init 初始化多数据源 +func init() { + ctx := context.Background() + + // 从配置初始化多数据源 + if err := GetManager().InitializeFromConfig(ctx); err != nil { + glog.Errorf(ctx, "❌ Failed to initialize Redis datasources: %v", err) + } else { + glog.Infof(ctx, "✅ Redis datasources initialized: %v", GetManager().GetAllDataSourceNames()) + } + + // 启动健康检查 + GetManager().StartHealthCheck() + + // 设置优雅关闭 + setupGracefulShutdown() +} + +// InitializeFromConfig 从配置初始化数据源 +// 动态读取 config.yml 中 redis 下的所有配置项 +func (m *DataSourceManager) InitializeFromConfig(ctx context.Context) error { + var firstErr error + + // 获取 redis 配置下的所有子键 + redisConfig := g.Cfg().MustGet(ctx, "redis") + if redisConfig.IsNil() { + glog.Warningf(ctx, "no redis configuration found in config.yml") + return nil + } + + // 将配置转换为 map + configMap := redisConfig.Map() + if configMap == nil { + glog.Warningf(ctx, "redis configuration is not a map") + return nil + } + + // 遍历所有 redis 子配置 + for name, subConfig := range configMap { + // 跳过非对象类型的配置 + subMap, ok := subConfig.(map[string]interface{}) + if !ok { + continue + } + + // 检查是否有 address 配置 + address, hasAddress := subMap["address"] + if !hasAddress || gconv.String(address) == "" { + continue + } + + // 构建数据源配置 + config := &RedisDataSourceConfig{ + Name: name, + Address: gconv.String(address), + Db: gconv.Int(subMap["db"]), + Pass: gconv.String(subMap["pass"]), + } + + // 设置默认值 + if config.Db == 0 { + config.Db = 0 + } + if config.Timeout == 0 { + config.Timeout = 10 * time.Second + } + if config.MaxIdle == 0 { + config.MaxIdle = 10 + } + if config.MaxOpen == 0 { + config.MaxOpen = 100 + } + + // 注册数据源 + if err := m.RegisterDataSource(config); err != nil { + glog.Errorf(ctx, "failed to register datasource [%s]: %v", name, err) + if firstErr == nil { + firstErr = err + } + continue + } + + // 连接数据源 + source, _ := m.GetDataSource(name) + if err := source.Connect(ctx); err != nil { + glog.Errorf(ctx, "failed to initialize datasource [%s]: %v", name, err) + if firstErr == nil { + firstErr = err + } + } + } + + return firstErr +} + +// StartHealthCheck 启动健康检查 +func (m *DataSourceManager) StartHealthCheck() { + if m.started { + return + } + m.started = true + go m.healthCheckLoop() +} + +// healthCheckLoop 健康检查循环 +func (m *DataSourceManager) healthCheckLoop() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case <-m.ctx.Done(): + return + case <-ticker.C: + m.checkAndReconnect() + } + } +} + +// checkAndReconnect 检查并重新连接 +func (m *DataSourceManager) checkAndReconnect() { + m.mu.RLock() + defer m.mu.RUnlock() + + for name, source := range m.sources { + if !source.IsConnected() { + glog.Warningf(context.Background(), "datasource [%s] disconnected, attempting reconnect", name) + + reconnectCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + if err := source.Reconnect(reconnectCtx); err != nil { + glog.Errorf(reconnectCtx, "datasource [%s] reconnect failed: %v", name, err) + } else { + glog.Infof(reconnectCtx, "✅ datasource [%s] reconnected successfully", name) + } + } + } +} + +// CloseAll 关闭所有数据源 +func (m *DataSourceManager) CloseAll(ctx context.Context) error { + m.cancel() + + m.mu.RLock() + defer m.mu.RUnlock() + + var lastErr error + for name, source := range m.sources { + if err := source.Close(ctx); err != nil { + glog.Errorf(ctx, "failed to close datasource [%s]: %v", name, err) + lastErr = err + } + } + return lastErr +} + +// setupGracefulShutdown 设置优雅关闭 +func setupGracefulShutdown() { + go func() { + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + <-sigCh + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + glog.Info(ctx, "🔄 Shutting down Redis connections...") + if err := GetManager().CloseAll(ctx); err != nil { + glog.Errorf(ctx, "❌ Failed to close Redis connections: %v", err) + } else { + glog.Info(ctx, "✅ Redis connections closed successfully") + } + }() +}