diff --git a/nats/nats_client.go b/nats/nats_client.go new file mode 100644 index 0000000..a4afebb --- /dev/null +++ b/nats/nats_client.go @@ -0,0 +1,313 @@ +package nats + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/gogf/gf/v2/frame/g" + "github.com/nats-io/nats.go" + "github.com/nats-io/nats.go/jetstream" +) + +var ( + nc *nats.Conn + js jetstream.JetStream + inited bool + natsMu sync.RWMutex + natsURL string + healthCtx context.Context + healthCancel context.CancelFunc + connected bool + reconnectChan chan struct{} + + // 连接状态变化监听器 + connStateListeners []connStateListener + connListenersMu sync.RWMutex + + // 监控指标 + metrics metricsCounter +) + +// Metrics 监控指标 +type metricsCounter struct { + PublishCount atomic.Int64 + PublishError atomic.Int64 + SubscribeCount atomic.Int64 + RequestCount atomic.Int64 + RequestError atomic.Int64 + ConsumeCount atomic.Int64 + ConsumeError atomic.Int64 +} + +// ConnState 连接状态 +type connState int + +const ( + connStateDisconnected connState = iota + connStateConnecting + connStateConnected + connStateReconnecting + connStateClosed +) + +// ConnStateListener 连接状态监听器 +type connStateListener func(state connState, err error) + +// GetMetrics 获取监控指标 +func getMetrics() metricsCounter { + return metrics +} + +// registerConnStateListener 注册连接状态监听器 +func registerConnStateListener(listener connStateListener) { + connListenersMu.Lock() + defer connListenersMu.Unlock() + connStateListeners = append(connStateListeners, listener) +} + +// unregisterConnStateListener 取消注册连接状态监听器 +func unregisterConnStateListener(listener connStateListener) { + connListenersMu.Lock() + defer connListenersMu.Unlock() + for i, l := range connStateListeners { + if l != nil && &l == &listener { + connStateListeners = append(connStateListeners[:i], connStateListeners[i+1:]...) + break + } + } +} + +// notifyConnState 通知所有监听器连接状态变化 +func notifyConnState(state connState, err error) { + connListenersMu.RLock() + listeners := make([]connStateListener, len(connStateListeners)) + copy(listeners, connStateListeners) + connListenersMu.RUnlock() + + for _, listener := range listeners { + if listener != nil { + listener(state, err) + } + } +} + +// init 初始化 NATS 连接 +func init() { + // 从配置文件读取 NATS 地址 + natsURL = g.Cfg().MustGet(context.Background(), "nats.url").String() + if natsURL == "" { + // 默认使用本地地址 + natsURL = nats.DefaultURL + } + + // 创建健康检查上下文 + healthCtx, healthCancel = context.WithCancel(context.Background()) + + // 创建重连通知通道(增大缓冲区避免丢失通知) + reconnectChan = make(chan struct{}, 10) + + // 启动连接 + go initConnection() + + // 启动健康检查协程 + go healthCheck() +} + +// initConnection 初始化连接 +func initConnection() { + ctx := context.Background() + notifyConnState(connStateConnecting, nil) + if err := connect(ctx); err != nil { + g.Log().Errorf(ctx, "NATS 初始连接失败: %v", err) + notifyConnState(connStateDisconnected, err) + } +} + +// connect 建立 NATS 连接 +func connect(ctx context.Context) error { + natsMu.Lock() + defer natsMu.Unlock() + + if nc != nil && !nc.IsClosed() { + nc.Close() + } + + // 连接选项配置 + opts := []nats.Option{ + nats.Name("goframe-nats-client"), + nats.ReconnectWait(2 * time.Second), + nats.MaxReconnects(-1), // 无限重连 + nats.PingInterval(10 * time.Second), + nats.MaxPingsOutstanding(5), + nats.ReconnectHandler(func(nc *nats.Conn) { + g.Log().Infof(ctx, "✅ NATS 重连成功: %s", nc.ConnectedUrl()) + connected = true + + // 重新创建 JetStream 实例 + if newJS, err := jetstream.New(nc); err == nil { + js = newJS + } + + // 通知重连成功 + notifyConnState(connStateConnected, nil) + + // 使用非阻塞发送避免阻塞 + select { + case reconnectChan <- struct{}{}: + default: + // 通道已满,丢弃通知 + } + }), + nats.DisconnectErrHandler(func(nc *nats.Conn, err error) { + g.Log().Warningf(ctx, "⚠️ NATS 连接断开: %v, 准备重连...", err) + connected = false + notifyConnState(connStateReconnecting, err) + }), + nats.ClosedHandler(func(nc *nats.Conn) { + g.Log().Infof(ctx, "NATS 连接已关闭: %s", nc.ConnectedUrl()) + connected = false + notifyConnState(connStateClosed, nil) + }), + nats.ErrorHandler(func(nc *nats.Conn, sub *nats.Subscription, err error) { + g.Log().Errorf(ctx, "NATS 错误: %v", err) + }), + } + + var err error + nc, err = nats.Connect(natsURL, opts...) + if err != nil { + return fmt.Errorf("NATS 连接失败: %w", err) + } + + // 等待连接就绪 + if nc.Status() != nats.CONNECTED { + select { + case <-time.After(5 * time.Second): + notifyConnState(connStateDisconnected, fmt.Errorf("连接超时")) + return fmt.Errorf("NATS 连接超时") + case <-nc.StatusChanged(nats.CONNECTED): + } + } + + // 创建 JetStream 实例 + js, err = jetstream.New(nc) + if err != nil { + return fmt.Errorf("创建 JetStream 失败: %w", err) + } + + connected = true + inited = true + g.Log().Infof(ctx, "✅ NATS 连接成功: %s", nc.ConnectedUrl()) + notifyConnState(connStateConnected, nil) + return nil +} + +// healthCheck 健康检查协程(仅作为备用检查) +func healthCheck() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case <-healthCtx.Done(): + return + case <-ticker.C: + natsMu.RLock() + currentConnected := connected + currentConn := nc + natsMu.RUnlock() + + if !currentConnected || currentConn == nil || currentConn.IsClosed() { + // 仅记录日志,不尝试重连(NATS 已有自动重连机制) + g.Log().Warning(context.Background(), "NATS 连接断开,等待 NATS 自动重连...") + } + case <-reconnectChan: + // 重连成功的通知(仅记录日志) + g.Log().Info(context.Background(), "收到重连成功通知") + } + } +} + +// checkConnected 检查连接状态 +func checkConnected() bool { + natsMu.RLock() + defer natsMu.RUnlock() + return connected && nc != nil && !nc.IsClosed() +} + +// getConnState 获取当前连接状态 +func getConnState() connState { + natsMu.RLock() + defer natsMu.RUnlock() + + if nc == nil { + return connStateDisconnected + } + + if nc.IsClosed() { + return connStateClosed + } + + if connected { + return connStateConnected + } + + return connStateDisconnected +} + +// shutdown 优雅关闭:自动注销所有已注册的服务并关闭 NATS 连接 +func shutdown() error { + ctx := context.Background() + g.Log().Info(ctx, "开始优雅关闭 NATS RPC 服务...") + + // 注销所有单实例服务 + rpcServicesMu.Lock() + singleServiceCount := len(rpcServices) + for serviceName := range rpcServices { + if sub, exists := rpcSubs[serviceName]; exists { + if err := sub.Unsubscribe(); err != nil { + g.Log().Errorf(ctx, "注销服务 %s 失败: %v", serviceName, err) + } + } + delete(rpcSubs, serviceName) + delete(rpcServices, serviceName) + } + rpcServicesMu.Unlock() + + // 注销所有队列服务 + queueRPCMu.Lock() + queueServiceCount := 0 + for queueName, servicesMap := range queueRPCServices { + queueServiceCount += len(servicesMap) + for serviceName, sub := range queueRPCSubs[queueName] { + if err := sub.Unsubscribe(); err != nil { + g.Log().Errorf(ctx, "注销队列服务 %s (队列: %s) 失败: %v", serviceName, queueName, err) + } + } + delete(queueRPCSubs, queueName) + delete(queueRPCServices, queueName) + } + queueRPCMu.Unlock() + + g.Log().Infof(ctx, "已注销 %d 个单实例服务和 %d 个队列服务", singleServiceCount, queueServiceCount) + + natsMu.Lock() + defer natsMu.Unlock() + + // 停止健康检查协程 + if healthCancel != nil { + healthCancel() + } + + // 关闭连接 + if nc != nil && !nc.IsClosed() { + nc.Close() + connected = false + inited = false + } + g.Log().Info(ctx, "NATS RPC 服务已优雅关闭") + return nil +} diff --git a/nats/nats_rpc.go b/nats/nats_rpc.go new file mode 100644 index 0000000..e95514b --- /dev/null +++ b/nats/nats_rpc.go @@ -0,0 +1,752 @@ +package nats + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "github.com/gogf/gf/v2/frame/g" + "github.com/nats-io/nats.go" + "go.opentelemetry.io/otel/trace" + "reflect" + "sync" +) + +// ============ RPC 服务封装 ============ +// 以下方法提供了完全抽象的 RPC 调用接口 +// 调用方和响应方完全不需要知道底层使用的是 NATS 的发布订阅模式 + +// RPC 服务注册表 +var ( + rpcServices map[string]rpcHandler + rpcSubs map[string]*nats.Subscription // 服务名 -> 订阅 + rpcServicesMu sync.RWMutex + queueRPCServices map[string]map[string]rpcHandler // queueName -> subject -> handler + queueRPCSubs map[string]map[string]*nats.Subscription // queueName -> serviceName -> 订阅 + queueRPCMu sync.RWMutex + + // ============ TraceID 主动取消支持 ============ + // 全局映射表:TraceID -> CancelFunc,并发安全 + traceCancelMap map[string]context.CancelFunc + traceCancelMu sync.RWMutex + // 取消主题前缀 + cancelSubjectPrefix = "ctx.cancel.otel." +) + +// rpcHandler RPC 处理函数类型 +// 实现方只需要关注请求参数和返回值,无需了解底层 NATS 实现 +// 返回值可以是任意类型,会被自动序列化为 JSON +type rpcHandler func(ctx context.Context, req []byte) (any, error) + +// RegisterRPCService 注册 RPC 服务(单实例) +// serviceName: 服务名称,调用方通过此名称调用服务 +// handler: 服务处理函数,接收请求并返回响应 +func registerRPCService(serviceName string, handler rpcHandler) (err error) { + if !checkConnected() { + return fmt.Errorf("NATS 未连接") + } + + rpcServicesMu.Lock() + if rpcServices == nil { + rpcServices = make(map[string]rpcHandler) + } + if rpcSubs == nil { + rpcSubs = make(map[string]*nats.Subscription) + } + + // 如果已存在该服务,先取消之前的订阅 + if oldSub, exists := rpcSubs[serviceName]; exists { + oldSub.Unsubscribe() + } + + rpcServices[serviceName] = handler + rpcServicesMu.Unlock() + + // 订阅服务主题 + subject := fmt.Sprintf("rpc.%s", serviceName) + sub, err := nc.Subscribe(subject, func(msg *nats.Msg) { + // 执行处理函数 + executeHandler(handler, msg) + }) + + if err != nil { + return fmt.Errorf("注册 RPC 服务失败: %w", err) + } + + rpcSubs[serviceName] = sub + metrics.SubscribeCount.Add(1) + g.Log().Infof(context.Background(), "✅ RPC 服务已注册: %s", serviceName) + return nil +} + +// RegisterQueueRPCService 注册 RPC 服务(集群模式) +// 多个服务实例注册同一服务时,请求会自动负载均衡 +// serviceName: 服务名称 +// queueName: 队列组名,同一队列组的实例共享请求 +// handler: 服务处理函数 +func registerQueueRPCService(serviceName, queueName string, handler rpcHandler) (err error) { + if !checkConnected() { + return fmt.Errorf("NATS 未连接") + } + + queueRPCMu.Lock() + if queueRPCServices == nil { + queueRPCServices = make(map[string]map[string]rpcHandler) + } + if queueRPCSubs == nil { + queueRPCSubs = make(map[string]map[string]*nats.Subscription) + } + if queueRPCServices[queueName] == nil { + queueRPCServices[queueName] = make(map[string]rpcHandler) + } + if queueRPCSubs[queueName] == nil { + queueRPCSubs[queueName] = make(map[string]*nats.Subscription) + } + + // 如果已存在该服务,先取消之前的订阅 + if oldSub, exists := queueRPCSubs[queueName][serviceName]; exists { + oldSub.Unsubscribe() + } + + queueRPCServices[queueName][serviceName] = handler + queueRPCMu.Unlock() + + // 订阅服务主题(队列模式) + subject := fmt.Sprintf("rpc.%s", serviceName) + sub, err := nc.QueueSubscribe(subject, queueName, func(msg *nats.Msg) { + // 执行处理函数 + executeHandler(handler, msg) + }) + + if err != nil { + return fmt.Errorf("注册队列 RPC 服务失败: %w", err) + } + + queueRPCMu.Lock() + queueRPCSubs[queueName][serviceName] = sub + queueRPCMu.Unlock() + + metrics.SubscribeCount.Add(1) + g.Log().Infof(context.Background(), "✅ 队列 RPC 服务已注册: %s (队列组: %s)", serviceName, queueName) + return nil +} + +// executeHandler 执行 RPC 处理函数 +func executeHandler(handler rpcHandler, msg *nats.Msg) { + // 响应 + var respData []byte + // 从消息头重建上下文 + ctx := headersToContext(context.Background(), msg.Header) + // 提取 TraceID,创建可取消的 context + ctx = createCancelContext(ctx, msg.Header.Get(TraceIDKey)) + // 检查 context 是否已取消(在调用 handler 之前) + select { + case <-ctx.Done(): + // context 已取消,返回取消错误 + g.Log().Infof(ctx, "RPC 请求已取消,traceID: %s", msg.Header.Get(TraceIDKey)) + // 仍然需要发送响应以避免客户端超时 + respData = []byte(`{"_err":"请求已取消"}`) + // 清理取消映射表 + cleanupTraceCancel(msg.Header.Get(TraceIDKey)) + return + default: + } + + // 执行业务处理 + response, err := handler(ctx, msg.Data) + + if err != nil { + // 错误时返回 {"_err": "错误信息"} + if respData, err = json.Marshal(map[string]any{"_err": err.Error()}); err != nil { + g.Log().Errorf(ctx, "RPC 错误响应序列化失败: %v", err) + respData = []byte(`{"_err":"错误响应序列化失败"}`) + } + } else if response == nil { + // 空响应时返回空对象(或 {"_err": ""}) + respData = []byte(`{}`) + } else { + // 成功时返回业务数据 + if respData, err = json.Marshal(response); err != nil { + g.Log().Errorf(ctx, "RPC 响应序列化失败: %v", err) + respData = []byte(`{"_err":"响应序列化失败"}`) + } + } + // 发送响应(必须执行) 如果客户端用 nc.Request(...) 发送消息 → 双向模式,服务端必须 msg.Respond + if err = msg.Respond(respData); err != nil { + g.Log().Errorf(ctx, "RPC 响应失败: %v", err) + } + // 请求结束,清理取消映射表 + cleanupTraceCancel(msg.Header.Get(TraceIDKey)) +} + +// createCancelContext 创建可取消的 context 并注册到取消映射表 +// 返回可取消的 context(如果 traceID 为空则返回原 context) +func createCancelContext(ctx context.Context, traceID string) context.Context { + if g.IsEmpty(traceID) { + return ctx + } + // 创建带取消功能的 context + taskCtx, cancel := context.WithCancel(ctx) + // 注册到取消映射表 + traceCancelMu.Lock() + if traceCancelMap == nil { + traceCancelMap = make(map[string]context.CancelFunc) + } + // 如果同一 TraceID 已有 CancelFunc,先调用它 + if oldCancel, exists := traceCancelMap[traceID]; exists { + oldCancel() + } + traceCancelMap[traceID] = cancel + traceCancelMu.Unlock() + + return taskCtx +} + +// ============ TraceID 主动取消功能 ============ +// 以下函数实现了基于 OpenTelemetry TraceID 的跨进程任务取消机制 + +// SetupCancelListener 设置取消监听器 +// 订阅取消主题,监听取消指令 +// 使用示例: +// +// sub, err := nats.SetupCancelListener(ctx) +func setupCancelListener(ctx context.Context) (*nats.Subscription, error) { + if !checkConnected() { + return nil, fmt.Errorf("NATS 未连接") + } + + if traceCancelMap == nil { + traceCancelMap = make(map[string]context.CancelFunc) + } + + // 修复问题3:订阅取消主题,格式: ctx.cancel.otel.* + // 使用 * 通配符而不是 >,因为 TraceID 是最后一部分 + cancelSubject := cancelSubjectPrefix + "*" + sub, err := nc.Subscribe(cancelSubject, func(msg *nats.Msg) { + // 从主题中解析 TraceID (去除前缀) + prefixLen := len(cancelSubjectPrefix) + if len(msg.Subject) <= prefixLen { + g.Log().Warningf(ctx, "取消消息主题格式错误: %s", msg.Subject) + return + } + traceID := msg.Subject[prefixLen:] + + if traceID == "" { + g.Log().Warning(ctx, "取消消息主题缺少 TraceID") + return + } + + // 从映射表获取 CancelFunc 并执行取消 + traceCancelMu.RLock() + cancel, ok := traceCancelMap[traceID] + traceCancelMu.RUnlock() + + if ok { + cancel() + g.Log().Infof(ctx, "📢 取消信号已发送,traceID: %s", traceID) + } else { + g.Log().Infof(ctx, "⚠️ 未找到对应的可取消任务,traceID: %s", traceID) + } + }) + + if err != nil { + return nil, fmt.Errorf("设置取消监听器失败: %w", err) + } + + metrics.SubscribeCount.Add(1) + g.Log().Infof(ctx, "✅ 取消监听器已设置: %s", cancelSubject) + return sub, nil +} + +// publishCancel 发布取消指令 +// 向指定 TraceID 发送取消信号 +// 使用示例: +// +// err := nats.publishCancel(ctx, traceID) +func publishCancel(ctx context.Context, traceID string) error { + if !checkConnected() { + return fmt.Errorf("NATS 未连接") + } + + if traceID == "" { + return fmt.Errorf("TraceID 不能为空") + } + + cancelSubject := cancelSubjectPrefix + traceID + err := nc.Publish(cancelSubject, nil) + if err != nil { + return fmt.Errorf("发布取消信号失败: %w", err) + } + + g.Log().Infof(ctx, "📤 已发送取消信号,traceID: %s,主题: %s", traceID, cancelSubject) + return nil +} + +// cleanupTraceCancel 清理取消映射表中的条目 +// 任务取消/正常结束后必须调用此函数,避免内存泄漏 +// 使用示例: +// +// defer nats.cleanupTraceCancel(traceID) +func cleanupTraceCancel(traceID string) { + if traceID == "" { + return + } + + traceCancelMu.Lock() + defer traceCancelMu.Unlock() + + if _, ok := traceCancelMap[traceID]; ok { + delete(traceCancelMap, traceID) + g.Log().Infof(context.Background(), "✅ 已清理取消映射表,traceID: %s", traceID) + } +} + +// CallRPC 调用 RPC 服务 +// serviceName: 服务名称 +// req: 请求数据 +// 返回: 响应数据(任意类型)和错误 +func CallRPC(ctx context.Context, serviceName string, req any, resp any) (err error) { + if !checkConnected() { + return fmt.Errorf("NATS 未连接") + } + + metrics.RequestCount.Add(1) + + // 验证 resp 必须是指针类型 + respValue := reflect.ValueOf(resp) + if respValue.Kind() != reflect.Ptr { + return fmt.Errorf("resp 参数必须是指针类型(当前类型: %T)", resp) + } + + // 构建请求体 + var reqBody []byte + if !g.IsEmpty(req) { + reqValue := reflect.ValueOf(req) + if !(reqValue.Kind() == reflect.Ptr && reqValue.IsNil()) && !reqValue.IsZero() { + reqData, err := json.Marshal(req) + if err != nil { + return fmt.Errorf("序列化请求参数失败: %w", err) + } + reqBody = reqData + } + } + + // 检查本地是否有注册的单实例服务,如果有则直接调用(优化性能) + rpcServicesMu.RLock() + if localHandler, exists := rpcServices[serviceName]; exists { + rpcServicesMu.RUnlock() + + // 修复问题1:本地调用也需要处理取消机制 + var traceID string + if traceID, err = getTraceID(ctx); err != nil { + return err + } + // 提取 TraceID,创建可取消的 context + cancelCtx := createCancelContext(ctx, traceID) + // 执行本地调用 + var response interface{} + if response, err = localHandler(cancelCtx, reqBody); err != nil { + metrics.RequestError.Add(1) + return fmt.Errorf("本地调用 RPC 服务失败 [%s]: %w", serviceName, err) + } + + // 请求结束,清理取消映射表 + cleanupTraceCancel(traceID) + + // 检查是否为错误消息:尝试解析为 map,看是否包含 "_err" 字段 + var respMap map[string]any + if json.Unmarshal(response.([]byte), &respMap) == nil { + if errMsg, ok := respMap["_err"]; ok { + metrics.RequestError.Add(1) + return fmt.Errorf("%v", errMsg) + } + } + // 正常数据直接返回 + // responseMsg.Data 已经是 []byte 类型(来自 msg.Data),直接反序列化 + if err = json.Unmarshal(response.([]byte), resp); err != nil { + return fmt.Errorf("解析响应失败: %w (响应内容: %s)", err, response) + } + + return + } + rpcServicesMu.RUnlock() + + subject := fmt.Sprintf("rpc.%s", serviceName) + + // 创建消息并将上下文元数据写入消息头 + msg := nats.NewMsg(subject) + msg.Data = reqBody + headers, err := contextToHeaders(ctx) + if err != nil { + return fmt.Errorf("上下文转换失败: %w", err) + } + msg.Header = headers + + // 修复问题5:优化 go 协程避免资源泄漏 + // 使用 done channel 来确保 goroutine 能正确退出 + done := make(chan struct{}) + var closeDoneOnce sync.Once + closeDone := func() { + closeDoneOnce.Do(func() { + close(done) + }) + } + + if msg.Header.Get(TraceIDKey) != "" { + go func() { + defer closeDone() + select { + case <-ctx.Done(): + // context 被取消时,发送取消信号给服务端 + if errors.Is(ctx.Err(), context.Canceled) { + if err := publishCancel(context.Background(), msg.Header.Get(TraceIDKey)); err != nil { + g.Log().Errorf(ctx, "发送 RPC 取消信号失败: %v", err) + } else { + g.Log().Infof(ctx, "RPC 调用已取消,traceID: %s", msg.Header.Get(TraceIDKey)) + } + } + case <-done: + // 请求已完成,无需发送取消信号 + return + } + }() + } + + // 发送请求 + responseMsg, err := nc.RequestMsgWithContext(ctx, msg) + + // 关闭 done channel,通知 goroutine 退出 + closeDone() + + if err != nil { + metrics.RequestError.Add(1) + return fmt.Errorf("调用 RPC 服务失败 [%s]: %w", serviceName, err) + } + + if responseMsg == nil { + metrics.RequestError.Add(1) + return fmt.Errorf("RPC 响应为空 [%s]", serviceName) + } + + // 解析响应 + if len(responseMsg.Data) > 0 { + // 检查是否为错误消息:尝试解析为 map,看是否包含 "_err" 字段 + var respMap map[string]any + if json.Unmarshal(responseMsg.Data, &respMap) == nil { + if errMsg, ok := respMap["_err"]; ok { + metrics.RequestError.Add(1) + return fmt.Errorf("%v", errMsg) + } + } + // 正常数据直接返回 + // responseMsg.Data 已经是 []byte 类型(来自 msg.Data),直接反序列化 + if err = json.Unmarshal(responseMsg.Data, resp); err != nil { + return fmt.Errorf("解析响应失败: %w (响应内容: %s)", err, responseMsg.Data) + } + } + + return +} + +// RegisterServiceOption 注册选项类型 +type RegisterServiceOption func(*registerServiceConfig) + +type registerServiceConfig struct { + queueName string // 队列组名(用于集群模式) + excludeMethods []string +} + +// WithQueueGroup 设置队列组名(集群模式) +func WithQueueGroup(queueName string) RegisterServiceOption { + return func(cfg *registerServiceConfig) { + cfg.queueName = queueName + } +} + +// WithExcludeMethods 排除不需要注册的方法 +func WithExcludeMethods(methods ...string) RegisterServiceOption { + return func(cfg *registerServiceConfig) { + cfg.excludeMethods = append(cfg.excludeMethods, methods...) + } +} + +// AutoRegisterServices 自动注册多个服务的所有公开方法 +// serviceInstances: map[包名]service实例,如 map[string]interface{}{"user": userService, "order": orderService} +// options: 注册选项(可选) +// 示例: +// +// AutoRegisterServices(map[string]interface{}{ +// "user": userService, +// "order": orderService, +// }) +// 或 +// AutoRegisterServices(map[string]interface{}{ +// "order": orderService, +// }, WithQueueGroup("order-group")) +func AutoRegisterServices(ctx context.Context, serviceInstances map[string]interface{}, options ...RegisterServiceOption) error { + // 先注册 RPC 服务(如果 NATS 不可用则记录警告但不阻塞启动) + if !checkConnected() { + return fmt.Errorf("NATS 未连接,RPC 服务未注册") + } + + if len(serviceInstances) == 0 { + return fmt.Errorf("service 实例列表不能为空") + } + + totalRegistered := 0 + // 遍历每个 service 实例 + for pkgName, serviceInstance := range serviceInstances { + // 注册服务 + err := registerService(serviceInstance, pkgName, options...) + if err != nil { + g.Log().Errorf(ctx, "注册 %s 服务失败: %v", pkgName, err) + continue + } + totalRegistered++ + g.Log().Infof(ctx, "✅ %s 服务已自动注册", pkgName) + } + + if totalRegistered == 0 { + return fmt.Errorf("未能注册任何服务") + } + // 设置取消监听器(监听基于 TraceID 的取消请求) + if _, err := setupCancelListener(ctx); err != nil { + g.Log().Errorf(ctx, "设置取消监听器失败: %v", err) + } else { + g.Log().Infof(ctx, "✅ 取消监听器已自动设置") + } + g.Log().Infof(ctx, "✅ 共自动注册了 %d 个服务", totalRegistered) + + return nil +} + +// registerService 注册单个服务的所有公开方法(内部函数) +func registerService(service interface{}, serviceNamePrefix string, options ...RegisterServiceOption) (err error) { + if !checkConnected() { + return fmt.Errorf("NATS 未连接") + } + + // 应用选项 + cfg := ®isterServiceConfig{} + for _, opt := range options { + opt(cfg) + } + + // 创建排除方法集合 + excludeSet := make(map[string]struct{}) + for _, method := range cfg.excludeMethods { + excludeSet[method] = struct{}{} + } + + // 获取 service 的类型 + serviceType := reflect.TypeOf(service) + + // 遍历所有方法 + registeredCount := 0 + for i := 0; i < serviceType.NumMethod(); i++ { + method := serviceType.Method(i) + + // 只注册导出方法(首字母大写) + if !method.IsExported() { + continue + } + + // 排除指定的方法 + if _, exists := excludeSet[method.Name]; exists { + continue + } + + // 检查方法签名:必须是 func(ctx context.Context, request) (response, error) + // 注意:method.Type.NumIn() 包含接收者,所以实际参数数量需要减去 1 + // 要求:接收者 + context.Context + request,总共3个参数 + if method.Type.NumIn() != 3 { + g.Log().Warningf(context.Background(), "方法 %s 必须有2个参数(context.Context 和请求参数),跳过注册", method.Name) + continue + } + + // 第一个参数(接收者之后的第一个参数)必须是 context.Context + // method.Type.In(0) 是接收者,method.Type.In(1) 才是第一个参数 + if !method.Type.In(1).Implements(reflect.TypeOf((*context.Context)(nil)).Elem()) { + g.Log().Warningf(context.Background(), "方法 %s 的第一个参数必须是 context.Context,跳过注册", method.Name) + continue + } + + // 第二个参数必须是结构体指针或数组 + reqType := method.Type.In(2) + if reqType.Kind() != reflect.Ptr && reqType.Kind() != reflect.Slice && reqType.Kind() != reflect.Array { + g.Log().Warningf(context.Background(), "方法 %s 的第二个参数必须是结构体指针或数组,跳过注册", method.Name) + continue + } + + // 返回值必须是 (result, error),即2个返回值 + if method.Type.NumOut() != 2 { + g.Log().Warningf(context.Background(), "方法 %s 必须有2个返回值(result 和 error),跳过注册", method.Name) + continue + } + + // 最后一个返回值必须是 error + if !method.Type.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) { + g.Log().Warningf(context.Background(), "方法 %s 的最后一个返回值必须是 error,跳过注册", method.Name) + continue + } + + // 生成服务名称:前缀.方法名(保持原始方法名) + serviceName := fmt.Sprintf("%s.%s", serviceNamePrefix, method.Name) + + // 创建 RPC handler + handler := func(ctx context.Context, req []byte) (any, error) { + // 准备方法调用参数 + // args[0] 是接收者, args[1] 是 ctx, args[2] 是请求参数 + args := make([]reflect.Value, 3) + args[0] = reflect.ValueOf(service) // 接收者 + args[1] = reflect.ValueOf(ctx) // context.Context + + // 解析请求参数 + if len(req) > 0 { + reqValuePtr := reflect.New(reqType) + + // 解析 JSON + if err := json.Unmarshal(req, reqValuePtr.Interface()); err != nil { + // 根据参数类型提供更友好的错误提示 + var typeHint string + if reqType.Kind() == reflect.Ptr { + typeHint = fmt.Sprintf("(期望类型: %s)", reqType.Elem().Name()) + } else { // reflect.Slice 或 reflect.Array + typeHint = fmt.Sprintf("(期望类型: %s,请确保客户端传递的是JSON数组格式)", reqType.String()) + } + return nil, fmt.Errorf("解析请求参数失败%s: %w", typeHint, err) + } + args[2] = reqValuePtr.Elem() + } else { + // 请求为空,创建零值 + args[2] = reflect.Zero(method.Type.In(2)) + } + + // 调用方法 + results := method.Func.Call(args) + + // 处理返回值 + var result any + + if len(results) == 1 { + // 只有 error + if !results[0].IsNil() { + err = results[0].Interface().(error) + } + } else if len(results) == 2 { + // (result, error) + result = results[0].Interface() + if !results[1].IsNil() { + err = results[1].Interface().(error) + } + } + if err != nil { + return nil, err + } + + return result, nil + } + + // 注册 RPC 服务 + var err error + if cfg.queueName != "" { + err = registerQueueRPCService(serviceName, cfg.queueName, handler) + } else { + err = registerRPCService(serviceName, handler) + } + + if err != nil { + g.Log().Errorf(context.Background(), "注册服务 %s 失败: %v", serviceName, err) + continue + } + + registeredCount++ + g.Log().Infof(context.Background(), "✅ 已自动注册 RPC 服务: %s -> %s", serviceName, method.Name) + } + + if registeredCount == 0 { + g.Log().Warningf(context.Background(), "未注册任何方法,请检查 %v 的方法签名", serviceNamePrefix) + return fmt.Errorf("未找到可注册的方法") + } + + g.Log().Infof(context.Background(), "✅ Service %v 共注册了 %d 个 RPC 方法", serviceNamePrefix, registeredCount) + return nil +} + +// ============ 上下文元数据工具函数 ============ +// 以下函数用于在 context 和 NATS 消息头之间互转元数据 + +// 定义常见的上下文元数据 key +const ( + TraceIDKey = "trace_id" + TokenKey = "token" +) + +func getTraceID(ctx context.Context) (traceID string, err error) { + // 提取 traceId:首先尝试从 OpenTelemetry Span 中提取,从 context 中提取 TraceID + span := trace.SpanFromContext(ctx) + if span != nil && span.SpanContext().HasTraceID() { + traceID = span.SpanContext().TraceID().String() + } else if tid := ctx.Value(TraceIDKey); tid != nil { + traceID = fmt.Sprintf("%v", tid) + } + if traceID == "" { + return traceID, fmt.Errorf("context 中没有 TraceID") + } + return +} + +// contextToHeaders 将 context 中的元数据转换为 NATS 消息头 +// 支持提取 user_id、tenant_id、trace_id、token 等常见字段 +func contextToHeaders(ctx context.Context) (nats.Header, error) { + headers := make(nats.Header) + + // 提取 traceId:首先尝试从 OpenTelemetry Span 中提取 + if traceID, err := getTraceID(ctx); err != nil { + return headers, err + } else { + headers.Set(TraceIDKey, traceID) + } + + // 提取 token(优先级:context value > HTTP Authorization header) + token := "" + if t := ctx.Value(TokenKey); t != nil { + token = fmt.Sprintf("%v", t) + } else if r := g.RequestFromCtx(ctx); r != nil { + // 从 HTTP 请求的 Authorization header 中提取 token + auth := r.GetHeader("Authorization") + if auth != "" { + // 移除 "Bearer " 前缀 + if len(auth) > 7 && auth[:7] == "Bearer " { + token = auth[7:] + } else { + token = auth + } + } + } + if token != "" { + headers.Set(TokenKey, token) + } + + return headers, nil +} + +// headersToContext 从 NATS 消息头重建 context +// 支持还原 user_id、tenant_id、trace_id、token 等字段 +func headersToContext(ctx context.Context, headers nats.Header) context.Context { + if headers == nil { + return ctx + } + + // 恢复 trace_id + if traceID := headers.Get(TraceIDKey); traceID != "" { + ctx = context.WithValue(ctx, TraceIDKey, traceID) + } + + // 恢复 token + if token := headers.Get(TokenKey); token != "" { + ctx = context.WithValue(ctx, TokenKey, token) + } + + return ctx +}