88 lines
2.3 KiB
Go
88 lines
2.3 KiB
Go
|
|
package nats
|
|||
|
|
|
|||
|
|
import (
	"context"
	"fmt"
	"strings"

	"github.com/gogf/gf/v2/frame/g"
	"github.com/nats-io/nats.go"
	"go.opentelemetry.io/otel/trace"
)
|
|||
|
|
|
|||
|
|
// ============ Context metadata helpers ============
//
// The helpers below move request metadata between a context.Context and
// NATS message headers.

// Metadata keys shared by the context and the NATS headers. Each name is used
// both as the NATS header name and as the key for context.WithValue/ctx.Value.
//
// NOTE(review): plain string context keys can collide with other packages
// (staticcheck SA1029). They are kept as untyped strings because they are also
// passed to nats.Header.Set/Get and are exported.
const (
	// TraceIDKey identifies the trace ID entry.
	TraceIDKey = "trace_id"

	// TokenKey identifies the authentication token entry.
	TokenKey = "token"
)
|
|||
|
|
|
|||
|
|
func getTraceID(ctx context.Context) (traceID string, err error) {
|
|||
|
|
// 提取 traceId:首先尝试从 OpenTelemetry Span 中提取,从 context 中提取 TraceID
|
|||
|
|
span := trace.SpanFromContext(ctx)
|
|||
|
|
if span != nil && span.SpanContext().HasTraceID() {
|
|||
|
|
traceID = span.SpanContext().TraceID().String()
|
|||
|
|
} else if tid := ctx.Value(TraceIDKey); tid != nil {
|
|||
|
|
traceID = fmt.Sprintf("%v", tid)
|
|||
|
|
}
|
|||
|
|
if traceID == "" {
|
|||
|
|
return traceID, fmt.Errorf("context 中没有 TraceID")
|
|||
|
|
}
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// contextToHeaders 将 context 中的元数据转换为 NATS 消息头
|
|||
|
|
// 支持提取 user_id、tenant_id、trace_id、token 等常见字段
|
|||
|
|
func contextToHeaders(ctx context.Context) (nats.Header, error) {
|
|||
|
|
headers := make(nats.Header)
|
|||
|
|
|
|||
|
|
// 提取 traceId:首先尝试从 OpenTelemetry Span 中提取
|
|||
|
|
if traceID, err := getTraceID(ctx); err != nil {
|
|||
|
|
return headers, err
|
|||
|
|
} else {
|
|||
|
|
headers.Set(TraceIDKey, traceID)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 提取 token(优先级:context value > HTTP Authorization header)
|
|||
|
|
token := ""
|
|||
|
|
if t := ctx.Value(TokenKey); t != nil {
|
|||
|
|
token = fmt.Sprintf("%v", t)
|
|||
|
|
} else if r := g.RequestFromCtx(ctx); r != nil {
|
|||
|
|
// 从 HTTP 请求的 Authorization header 中提取 token
|
|||
|
|
auth := r.GetHeader("Authorization")
|
|||
|
|
if auth != "" {
|
|||
|
|
// 移除 "Bearer " 前缀
|
|||
|
|
if len(auth) > 7 && auth[:7] == "Bearer " {
|
|||
|
|
token = auth[7:]
|
|||
|
|
} else {
|
|||
|
|
token = auth
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
if token != "" {
|
|||
|
|
headers.Set(TokenKey, token)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return headers, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// headersToContext 从 NATS 消息头重建 context
|
|||
|
|
// 支持还原 user_id、tenant_id、trace_id、token 等字段
|
|||
|
|
func headersToContext(ctx context.Context, headers nats.Header) context.Context {
|
|||
|
|
if headers == nil {
|
|||
|
|
return ctx
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 恢复 trace_id
|
|||
|
|
if traceID := headers.Get(TraceIDKey); traceID != "" {
|
|||
|
|
ctx = context.WithValue(ctx, TraceIDKey, traceID)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 恢复 token
|
|||
|
|
if token := headers.Get(TokenKey); token != "" {
|
|||
|
|
ctx = context.WithValue(ctx, TokenKey, token)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return ctx
|
|||
|
|
}
|