Files
common/middleware/circuit_breaker.go
2026-03-12 08:51:25 +08:00

1297 lines
40 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package middleware
import (
"context"
"fmt"
"net"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"gitee.com/red-future---jilin-g/common/redis"
"github.com/alibaba/sentinel-golang/api"
"github.com/alibaba/sentinel-golang/core/circuitbreaker"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/net/ghttp"
)
// CircuitBreakerState is the string form of a circuit breaker state, as
// exposed to listeners and in the health-check response.
type CircuitBreakerState string

// String-valued breaker states. Their integer counterparts (stateClosed,
// stateOpen, stateHalfOpen) are what is actually stored in
// CircuitBreakerInfo.State.
const (
	StateClosed CircuitBreakerState = "closed"
	StateOpen CircuitBreakerState = "open"
	StateHalfOpen CircuitBreakerState = "halfopen"
)
// Integer encodings of the breaker states, used with the atomic.Int64 held
// in CircuitBreakerInfo.State.
const (
	stateClosed int64 = 0
	stateOpen int64 = 1
	stateHalfOpen int64 = 2
)
// CircuitBreakerConfig holds the per-service circuit breaker settings loaded
// from the "circuitBreaker.<service>" configuration subtree.
type CircuitBreakerConfig struct {
	// Enabled gates the breaker: disabled services are skipped by
	// InitCircuitBreaker and bypassed by the middleware.
	Enabled bool
	// MaxFailures is the rule threshold for the error-count strategy and the
	// fallback value for MinRequestAmount when that is configured as 0.
	MaxFailures int
	// Timeout is the raw duration string; TimeoutParsed is its parsed form,
	// used as the rule's retry timeout and to compute NextRetryTime.
	Timeout string
	TimeoutParsed time.Duration
	// SuccessStatusCodes lists the HTTP status codes counted as success.
	SuccessStatusCodes []int
	// SlowRequestThreshold is the raw duration string; requests slower than
	// SlowRequestThresholdParsed are counted as slow.
	SlowRequestThreshold string
	SlowRequestThresholdParsed time.Duration
	// EnableSlidingWindow selects the slow-request-ratio strategy instead of
	// the error-count strategy.
	EnableSlidingWindow bool
	// FailureRateThreshold is the slow-ratio rule threshold (validated to be
	// within 0.0-1.0) unless the adaptive threshold is enabled.
	FailureRateThreshold float64
	// EnableFallback together with a non-empty FallbackMessage replaces the
	// default 503 response body.
	EnableFallback bool
	FallbackMessage string
	// RequestTimeout is a per-request context timeout in milliseconds;
	// values <= 0 disable it.
	RequestTimeout int
	// DistributedTTL is the TTL passed to SETEX for the shared Redis state
	// key; values <= 0 disable the distributed check and sync.
	DistributedTTL int
	// AdminIPs are exact client IPs allowed to call the admin handlers.
	AdminIPs []string
	// StatIntervalMs and MinRequestAmount are passed through to the Sentinel rule.
	StatIntervalMs int
	MinRequestAmount int
	// AdminCIDRs are the raw CIDR strings ("*" expands to all IPv4 and IPv6).
	AdminCIDRs []string
	// HalfOpenMaxRequests caps the probe requests admitted while half-open.
	HalfOpenMaxRequests int
	// HalfOpenSuccessThreshold is the passed/requests ratio at which a
	// half-open breaker is closed again.
	HalfOpenSuccessThreshold float64
	// WarmupDuration is the raw duration string; for WarmupDurationParsed
	// after (re)initialisation the middleware passes requests straight through.
	WarmupDuration string
	WarmupDurationParsed time.Duration
	// EnableAdaptiveThreshold makes the rule threshold the midpoint of
	// AdaptiveMinThreshold and AdaptiveMaxThreshold.
	EnableAdaptiveThreshold bool
	AdaptiveMinThreshold float64
	AdaptiveMaxThreshold float64
	// CIDRNetMasks is the parsed form of AdminCIDRs.
	CIDRNetMasks []*net.IPNet
}
// CircuitBreakerMetrics collects the runtime counters of one circuit
// breaker. Every field is an atomic.Int64 and is updated without locks.
type CircuitBreakerMetrics struct {
	// Request counters.
	TotalRequests atomic.Int64
	PassRequests atomic.Int64
	BlockRequests atomic.Int64
	FailureRequests atomic.Int64
	SlowRequests atomic.Int64
	// State transition counters.
	OpenCount atomic.Int64
	ClosedCount atomic.Int64
	HalfOpenCount atomic.Int64
	// Timestamps (Unix seconds).
	LastResetTime atomic.Int64
	LastOpenTime atomic.Int64
	NextRetryTime atomic.Int64
	LastCloseTime atomic.Int64
	LastHalfOpenTime atomic.Int64
	// Half-open probe statistics.
	HalfOpenRequests atomic.Int64
	HalfOpenPassed atomic.Int64
	HalfOpenFailed atomic.Int64
	// Latency metrics.
	TotalResponseTime atomic.Int64 // sum of response times (nanoseconds)
	MinResponseTime atomic.Int64 // smallest response time (nanoseconds)
	MaxResponseTime atomic.Int64 // largest response time (nanoseconds)
	// Window statistics (used to compute the success rate).
	WindowStartTime atomic.Int64 // start of the current window (Unix seconds)
	WindowRequests atomic.Int64 // requests seen in the current window
	WindowFailures atomic.Int64 // failures seen in the current window
}
// CircuitBreakerInfo is the runtime record of one service's circuit breaker.
type CircuitBreakerInfo struct {
	// ResourceName is the Sentinel resource name ("service:" + service name).
	ResourceName string
	// State holds one of stateClosed, stateOpen or stateHalfOpen.
	State atomic.Int64
	Config *CircuitBreakerConfig
	Metrics *CircuitBreakerMetrics
	// SuccessCodeMap is the lookup-set form of Config.SuccessStatusCodes.
	SuccessCodeMap map[int]bool
	// CIDRNetMasks mirrors Config.CIDRNetMasks.
	CIDRNetMasks []*net.IPNet
	// AdaptiveThreshold is the threshold computed when the breaker was initialised.
	AdaptiveThreshold float64
	// WarmupEndTime is the Unix second until which the middleware bypasses
	// the breaker. It is a plain int64, not an atomic.
	WarmupEndTime int64
}
// Package-level registries shared by the middleware and the admin handlers.
var (
	// circuitBreakers maps service name -> *CircuitBreakerInfo.
	circuitBreakers sync.Map
	// circuitBreakerConfigs maps service name -> *CircuitBreakerConfig.
	circuitBreakerConfigs sync.Map
	// stateChangeListeners maps listener name -> StateChangeListener.
	stateChangeListeners sync.Map
	// stateChangeListenersRegistered guards the one-time registration of the
	// default listener.
	stateChangeListenersRegistered sync.Map
	// allowedAdminIPsMap and allowedAdminCIDRs are the merged admin
	// whitelist, rebuilt by updateAdminIPsCache; each has its own mutex.
	allowedAdminIPsMap map[string]bool
	allowedAdminIPsMutex sync.RWMutex
	allowedAdminCIDRs []*net.IPNet
	allowedAdminCIDRsMutex sync.RWMutex
	// totalServicesCount and serviceNamesSlice record the services found in
	// configuration by InitCircuitBreaker.
	totalServicesCount atomic.Int64
	serviceNamesSlice []string
	serviceNamesMutex sync.RWMutex
)
// Default values applied by loadServiceCircuitBreakerConfig when a key is
// missing from configuration (or, for durations, cannot be parsed).
const (
	defaultMaxFailures = 5
	defaultTimeout = "60s"
	defaultSlowRequestThreshold = "3s"
	defaultStatIntervalMs = 1000
	defaultRequestTimeout = 30000
	defaultDistributedTTL = 300
	defaultHalfOpenMaxRequests = 5
	defaultWarmupDuration = "10s"
	defaultHalfOpenSuccessThreshold = 0.5
)
// getState reports the breaker's current state. Any stored value other than
// stateOpen or stateHalfOpen is reported as StateClosed.
func (cb *CircuitBreakerInfo) getState() CircuitBreakerState {
	raw := cb.State.Load()
	if raw == stateOpen {
		return StateOpen
	}
	if raw == stateHalfOpen {
		return StateHalfOpen
	}
	return StateClosed
}
// setState switches the breaker to state, always updating the transition
// metrics, and returns the state it replaced.
func (cb *CircuitBreakerInfo) setState(state CircuitBreakerState) CircuitBreakerState {
	const updateMetrics = true
	previous := cb.setStateWithMetrics(state, updateMetrics)
	return previous
}
// setStateWithMetrics atomically swaps in the new state and returns the
// previous one. The transition metrics are updated only when updateMetrics is
// true and the state actually changed.
func (cb *CircuitBreakerInfo) setStateWithMetrics(state CircuitBreakerState, updateMetrics bool) CircuitBreakerState {
	previous := cb.int64ToState(cb.State.Swap(cb.stateToInt64(state)))
	if !updateMetrics || previous == state {
		return previous
	}
	cb.updateStateMetrics(state)
	return previous
}
// init puts the breaker into the closed state and stamps the reset, close and
// window-start timestamps with the current Unix time.
func (cb *CircuitBreakerInfo) init() {
	cb.State.Store(stateClosed)

	metrics := cb.Metrics
	metrics.LastResetTime.Store(time.Now().Unix())
	metrics.LastCloseTime.Store(time.Now().Unix())
	metrics.WindowStartTime.Store(time.Now().Unix())
}
// stateToInt64 maps a CircuitBreakerState to its integer encoding. Unknown
// values map to stateClosed.
func (cb *CircuitBreakerInfo) stateToInt64(state CircuitBreakerState) int64 {
	if state == StateOpen {
		return stateOpen
	}
	if state == StateHalfOpen {
		return stateHalfOpen
	}
	return stateClosed
}
// int64ToState maps an integer encoding back to its CircuitBreakerState.
// Unknown values map to StateClosed.
func (cb *CircuitBreakerInfo) int64ToState(state int64) CircuitBreakerState {
	if state == stateOpen {
		return StateOpen
	}
	if state == stateHalfOpen {
		return StateHalfOpen
	}
	return StateClosed
}
// updateStateMetrics records a transition into state: it bumps the matching
// transition counter and timestamp, and for StateOpen also schedules the next
// retry time at now + Config.TimeoutParsed.
func (cb *CircuitBreakerInfo) updateStateMetrics(state CircuitBreakerState) {
	stamp := time.Now().Unix()
	m := cb.Metrics

	if state == StateOpen {
		m.OpenCount.Add(1)
		m.LastOpenTime.Store(stamp)
		retryAt := time.Now().Add(cb.Config.TimeoutParsed)
		m.NextRetryTime.Store(retryAt.Unix())
		return
	}
	if state == StateClosed {
		m.ClosedCount.Add(1)
		m.LastCloseTime.Store(stamp)
		return
	}
	if state == StateHalfOpen {
		m.HalfOpenCount.Add(1)
		m.LastHalfOpenTime.Store(stamp)
	}
}
// getCircuitBreakerInfoAndConfig looks up the breaker registered for
// serviceName and returns it together with its config. Both results are nil
// when nothing (or a value of the wrong type) is registered.
func getCircuitBreakerInfoAndConfig(serviceName string) (*CircuitBreakerInfo, *CircuitBreakerConfig) {
	if stored, found := circuitBreakers.Load(serviceName); found {
		if info, isInfo := stored.(*CircuitBreakerInfo); isInfo {
			return info, info.Config
		}
	}
	return nil, nil
}
// updateResponseTimeStats folds one request duration into the latency
// metrics (total, min, max) and counts it as slow when it exceeds
// config.SlowRequestThresholdParsed.
func updateResponseTimeStats(cbInfo *CircuitBreakerInfo, duration time.Duration, config *CircuitBreakerConfig) {
	m := cbInfo.Metrics
	elapsedNs := int64(duration) // a Duration is a nanosecond count

	m.TotalResponseTime.Add(elapsedNs)
	atomicUpdateMin(&m.MinResponseTime, elapsedNs)
	atomicUpdateMax(&m.MaxResponseTime, elapsedNs)

	if slow := duration > config.SlowRequestThresholdParsed; slow {
		m.SlowRequests.Add(1)
	}
}
// formatUnixTime renders a Unix timestamp (seconds) as "YYYY-MM-DD hh:mm:ss"
// in the local time zone. Non-positive timestamps yield an empty string.
func formatUnixTime(timestamp int64) string {
	if timestamp <= 0 {
		return ""
	}
	const layout = "2006-01-02 15:04:05"
	return time.Unix(timestamp, 0).Format(layout)
}
// InitCircuitBreaker initialises Sentinel, registers the default state-change
// listener and creates a breaker for every enabled service found under the
// "circuitBreaker" configuration key. A failure to initialise one service is
// logged and does not abort the others; only a Sentinel init failure is
// returned as an error.
func InitCircuitBreaker() error {
	ctx := context.Background()
	if err := api.InitDefault(); err != nil {
		return fmt.Errorf("sentinel初始化失败: %v", err)
	}
	registerStateChangeListeners()
	g.Log().Infof(ctx, "Sentinel熔断器初始化成功")

	serviceNames := filterServiceNames(g.Cfg().MustGet(ctx, "circuitBreaker").Map())
	total := len(serviceNames)
	if total == 0 {
		g.Log().Infof(ctx, "未配置任何服务熔断器")
		return nil
	}

	// Publish the discovered service list for the admin handlers.
	totalServicesCount.Store(int64(total))
	serviceNamesMutex.Lock()
	serviceNamesSlice = serviceNames
	serviceNamesMutex.Unlock()

	var enabledCount int
	for _, name := range serviceNames {
		cfg := loadServiceCircuitBreakerConfig(name)
		if cfg == nil || !cfg.Enabled {
			continue
		}
		circuitBreakerConfigs.Store(name, cfg)
		if err := initServiceCircuitBreaker(name, cfg); err != nil {
			g.Log().Errorf(ctx, "服务 %s 熔断器初始化失败: %v", name, err)
			continue
		}
		g.Log().Infof(ctx, "服务 %s 熔断器初始化成功", name)
		enabledCount++
	}

	updateAdminIPsCache()
	g.Log().Infof(ctx, "共初始化 %d 个服务熔断器,其中 %d 个已启用", total, enabledCount)
	return nil
}
// ReloadCircuitBreakerConfig re-reads the configuration of serviceName,
// validates it and re-initialises the service's breaker with it.
//
// If re-initialisation fails, the config registry is rolled back to exactly
// what it held before the call: the previous config is restored, or the entry
// is removed when there was none, so an unapplied config is never left
// registered. Returned errors wrap the underlying cause.
func ReloadCircuitBreakerConfig(serviceName string) error {
	config := loadServiceCircuitBreakerConfig(serviceName)
	if config == nil {
		return fmt.Errorf("未找到服务 %s 的配置", serviceName)
	}
	if err := validateCircuitBreakerConfig(config); err != nil {
		return fmt.Errorf("配置验证失败: %w", err)
	}

	previous, hadPrevious := circuitBreakerConfigs.Load(serviceName)
	circuitBreakerConfigs.Store(serviceName, config)
	if err := initServiceCircuitBreaker(serviceName, config); err != nil {
		if hadPrevious {
			circuitBreakerConfigs.Store(serviceName, previous)
		} else {
			circuitBreakerConfigs.Delete(serviceName)
		}
		return fmt.Errorf("重新初始化熔断器失败: %w", err)
	}

	g.Log().Infof(context.Background(), "服务 %s 熔断器配置重新加载成功", serviceName)
	return nil
}
// loadServiceCircuitBreakerConfig builds the CircuitBreakerConfig for
// serviceName from the "circuitBreaker.<serviceName>" configuration subtree,
// applying the package defaults for missing keys.
//
// Durations that fail to parse fall back to their defaults (with a warning).
// A minRequestAmount of 0 falls back to MaxFailures. A failure reported by
// parseCIDRs is logged as a warning instead of being discarded; whatever
// networks it returned are still used. The result is never nil.
func loadServiceCircuitBreakerConfig(serviceName string) *CircuitBreakerConfig {
	ctx := context.Background()
	key := "circuitBreaker." + serviceName
	cfg := g.Cfg()

	config := &CircuitBreakerConfig{
		Enabled:                  cfg.MustGet(ctx, key+".enabled", true).Bool(),
		MaxFailures:              cfg.MustGet(ctx, key+".maxFailures", defaultMaxFailures).Int(),
		Timeout:                  cfg.MustGet(ctx, key+".timeout", defaultTimeout).String(),
		SlowRequestThreshold:     cfg.MustGet(ctx, key+".slowRequestThreshold", defaultSlowRequestThreshold).String(),
		EnableSlidingWindow:      cfg.MustGet(ctx, key+".enableSlidingWindow", false).Bool(),
		FailureRateThreshold:     cfg.MustGet(ctx, key+".failureRateThreshold", 0.5).Float64(),
		EnableFallback:           cfg.MustGet(ctx, key+".enableFallback", false).Bool(),
		FallbackMessage:          cfg.MustGet(ctx, key+".fallbackMessage", "").String(),
		RequestTimeout:           cfg.MustGet(ctx, key+".requestTimeout", defaultRequestTimeout).Int(),
		DistributedTTL:           cfg.MustGet(ctx, key+".distributedTTL", defaultDistributedTTL).Int(),
		StatIntervalMs:           cfg.MustGet(ctx, key+".statIntervalMs", defaultStatIntervalMs).Int(),
		HalfOpenMaxRequests:      cfg.MustGet(ctx, key+".halfOpenMaxRequests", defaultHalfOpenMaxRequests).Int(),
		HalfOpenSuccessThreshold: cfg.MustGet(ctx, key+".halfOpenSuccessThreshold", defaultHalfOpenSuccessThreshold).Float64(),
		WarmupDuration:           cfg.MustGet(ctx, key+".warmupDuration", defaultWarmupDuration).String(),
		EnableAdaptiveThreshold:  cfg.MustGet(ctx, key+".enableAdaptiveThreshold", false).Bool(),
		AdaptiveMinThreshold:     cfg.MustGet(ctx, key+".adaptiveMinThreshold", 0.3).Float64(),
		AdaptiveMaxThreshold:     cfg.MustGet(ctx, key+".adaptiveMaxThreshold", 0.7).Float64(),
	}

	config.MinRequestAmount = cfg.MustGet(ctx, key+".minRequestAmount", 0).Int()
	if config.MinRequestAmount == 0 {
		config.MinRequestAmount = config.MaxFailures
	}

	// Parse durations, falling back to the defaults on parse errors.
	config.TimeoutParsed, config.Timeout = parseDurationWithDefault(ctx, config.Timeout, defaultTimeout, "timeout")
	config.SlowRequestThresholdParsed, config.SlowRequestThreshold = parseDurationWithDefault(ctx, config.SlowRequestThreshold, defaultSlowRequestThreshold, "slowRequestThreshold")
	config.WarmupDurationParsed, config.WarmupDuration = parseDurationWithDefault(ctx, config.WarmupDuration, defaultWarmupDuration, "warmupDuration")

	// Parse the success status codes.
	successCodes := cfg.MustGet(ctx, key+".successStatusCodes", "200,201,204").String()
	config.SuccessStatusCodes = parseIntSlice(successCodes)

	// Parse the admin IP and CIDR whitelists.
	config.AdminIPs = parseStrings(cfg.MustGet(ctx, key+".adminIPs", "").String())
	config.AdminCIDRs = parseStrings(cfg.MustGet(ctx, key+".adminCIDRs", "").String())
	masks, err := parseCIDRs(config.AdminCIDRs)
	if err != nil {
		// Best effort: keep the networks that were returned, but make the
		// problem visible because it affects admin access control.
		g.Log().Warningf(ctx, "解析%s失败: %v", key+".adminCIDRs", err)
	}
	config.CIDRNetMasks = masks

	return config
}
// parseIntSlice splits str on commas and returns the fields that parse as
// integers (surrounding whitespace is ignored). Fields that do not parse are
// skipped. The result is never nil.
func parseIntSlice(str string) []int {
	fields := strings.Split(str, ",")
	values := make([]int, 0, len(fields))
	for i := range fields {
		n, err := strconv.Atoi(strings.TrimSpace(fields[i]))
		if err != nil {
			continue
		}
		values = append(values, n)
	}
	return values
}
// parseStrings splits a comma-separated list into its trimmed, non-empty
// items. An empty input yields nil; an input containing only separators or
// blanks yields an empty, non-nil slice.
func parseStrings(str string) []string {
	if len(str) == 0 {
		return nil
	}
	pieces := strings.Split(str, ",")
	kept := make([]string, 0, len(pieces))
	for _, piece := range pieces {
		piece = strings.TrimSpace(piece)
		if piece == "" {
			continue
		}
		kept = append(kept, piece)
	}
	return kept
}
// parseDurationWithDefault parses durationStr and returns the duration with
// the string it came from. On a parse error it logs a warning naming
// fieldName and returns the parsed defaultStr together with defaultStr.
func parseDurationWithDefault(ctx context.Context, durationStr, defaultStr, fieldName string) (time.Duration, string) {
	parsed, err := time.ParseDuration(durationStr)
	if err == nil {
		return parsed, durationStr
	}
	g.Log().Warningf(ctx, "解析%s失败: %s, 使用默认值 %s, error: %v", fieldName, durationStr, defaultStr, err)
	// The error is ignored here: an unparsable default simply yields 0.
	fallback, _ := time.ParseDuration(defaultStr)
	return fallback, defaultStr
}
// atomicUpdateMin lowers *minValue to newValue when newValue is smaller,
// retrying the compare-and-swap until it succeeds or the stored value is
// already less than or equal to newValue.
func atomicUpdateMin(minValue *atomic.Int64, newValue int64) {
	for current := minValue.Load(); newValue < current; current = minValue.Load() {
		if minValue.CompareAndSwap(current, newValue) {
			return
		}
	}
}
// atomicUpdateMax raises *maxValue to newValue when newValue is larger,
// retrying the compare-and-swap until it succeeds or the stored value is
// already greater than or equal to newValue.
func atomicUpdateMax(maxValue *atomic.Int64, newValue int64) {
	for current := maxValue.Load(); newValue > current; current = maxValue.Load() {
		if maxValue.CompareAndSwap(current, newValue) {
			return
		}
	}
}
// getAllowedIPsAndCIDRs returns the current admin IP set and CIDR list,
// reading both while holding their read locks.
func getAllowedIPsAndCIDRs() (map[string]bool, []*net.IPNet) {
	allowedAdminIPsMutex.RLock()
	defer allowedAdminIPsMutex.RUnlock()
	allowedAdminCIDRsMutex.RLock()
	defer allowedAdminCIDRsMutex.RUnlock()

	ips, cidrs := allowedAdminIPsMap, allowedAdminCIDRs
	return ips, cidrs
}
// getAllowedIPs returns the current admin IP set, read under its read lock.
// Kept for compatibility with older callers.
func getAllowedIPs() map[string]bool {
	allowedAdminIPsMutex.RLock()
	snapshot := allowedAdminIPsMap
	allowedAdminIPsMutex.RUnlock()
	return snapshot
}
// getAllowedCIDRs returns the current admin CIDR list, read under its read
// lock. Kept for compatibility with older callers.
func getAllowedCIDRs() []*net.IPNet {
	allowedAdminCIDRsMutex.RLock()
	snapshot := allowedAdminCIDRs
	allowedAdminCIDRsMutex.RUnlock()
	return snapshot
}
// reset returns every counter and latency metric to its initial value.
//
// All three state-transition counters (OpenCount, ClosedCount, HalfOpenCount)
// are cleared together so they stay mutually consistent after a reset.
// MinResponseTime is seeded with the largest int64 so the first observed
// duration always becomes the minimum. Timestamp fields are intentionally
// left untouched; LastResetTime is set by the caller.
func (m *CircuitBreakerMetrics) reset() {
	// Request counters.
	m.TotalRequests.Store(0)
	m.PassRequests.Store(0)
	m.BlockRequests.Store(0)
	m.FailureRequests.Store(0)
	m.SlowRequests.Store(0)

	// State-transition counters.
	m.OpenCount.Store(0)
	m.ClosedCount.Store(0)
	m.HalfOpenCount.Store(0)

	// Half-open probe statistics.
	m.HalfOpenRequests.Store(0)
	m.HalfOpenPassed.Store(0)
	m.HalfOpenFailed.Store(0)

	// Latency metrics.
	m.TotalResponseTime.Store(0)
	m.MinResponseTime.Store(1<<63 - 1) // max int64 as the initial minimum
	m.MaxResponseTime.Store(0)

	// Window statistics.
	m.WindowRequests.Store(0)
	m.WindowFailures.Store(0)
}
// parseCIDRs converts CIDR strings into networks.
//
// The wildcard "*" expands to both 0.0.0.0/0 and ::/0. Entries that are not
// valid CIDR notation are skipped, but they are no longer skipped silently:
// the valid networks are still returned, together with a non-nil error that
// lists the rejected entries. This matters because the admin whitelist built
// from these networks treats "no entries" as "allow everyone", so a typo must
// be visible to the caller. Callers that only need best-effort parsing may
// keep ignoring the error.
func parseCIDRs(strs []string) ([]*net.IPNet, error) {
	nets := make([]*net.IPNet, 0, len(strs))
	var invalid []string
	for _, s := range strs {
		if s == "*" {
			for _, all := range [...]string{"0.0.0.0/0", "::/0"} {
				if _, ipNet, err := net.ParseCIDR(all); err == nil {
					nets = append(nets, ipNet)
				}
			}
			continue
		}
		_, ipNet, err := net.ParseCIDR(s)
		if err != nil {
			invalid = append(invalid, s)
			continue
		}
		nets = append(nets, ipNet)
	}
	if len(invalid) > 0 {
		return nets, fmt.Errorf("invalid CIDR entries ignored: %q", invalid)
	}
	return nets, nil
}
// newCircuitBreakerMetrics allocates a metrics record and puts it into its
// initial (reset) state.
func newCircuitBreakerMetrics() *CircuitBreakerMetrics {
	m := new(CircuitBreakerMetrics)
	m.reset()
	return m
}
// updateWindowStats records one request outcome in the fixed 60-second
// statistics window and logs a warning when the window's success rate is
// below 50% with at least 10 samples.
//
// When the window has expired, the goroutine that wins the compare-and-swap
// on WindowStartTime clears the window counters; the others just continue
// counting. The previously present re-read of WindowStartTime after the
// rollover was never used and has been removed.
func (cb *CircuitBreakerInfo) updateWindowStats(isSuccess bool, ctx context.Context) {
	const windowSeconds int64 = 60 // default window size

	m := cb.Metrics
	now := time.Now().Unix()
	if windowStart := m.WindowStartTime.Load(); now-windowStart >= windowSeconds {
		// Only one goroutine gets to roll the window over.
		if m.WindowStartTime.CompareAndSwap(windowStart, now) {
			m.WindowRequests.Store(0)
			m.WindowFailures.Store(0)
		}
	}

	m.WindowRequests.Add(1)
	if !isSuccess {
		m.WindowFailures.Add(1)
	}

	// Compute the success rate of the current window.
	total := m.WindowRequests.Load()
	failures := m.WindowFailures.Load()
	if total > 0 {
		successRate := float64(total-failures) / float64(total)
		if successRate < 0.5 && total >= 10 { // low success rate with enough samples
			g.Log().Warningf(ctx, "熔断器 %s 窗口内成功率较低: %.2f%%, total=%d, failures=%d",
				cb.ResourceName, successRate*100, total, failures)
		}
	}
}
// validateInRange returns nil when min <= value <= max, otherwise an error
// naming the field and the allowed range.
func validateInRange(name string, value, min, max int) error {
	if min <= value && value <= max {
		return nil
	}
	return fmt.Errorf("%s必须在%d-%d之间", name, min, max)
}
// validateFloatInRange returns an error naming the field and the allowed
// range when value is below min or above max, and nil otherwise.
func validateFloatInRange(name string, value, min, max float64) error {
	outOfRange := value < min || value > max
	if !outOfRange {
		return nil
	}
	return fmt.Errorf("%s必须在%.1f-%.1f之间", name, min, max)
}
// validateCircuitBreakerConfig checks a config for out-of-range values and
// returns the first violation found, or nil. The adaptive thresholds are only
// checked when EnableAdaptiveThreshold is set.
func validateCircuitBreakerConfig(config *CircuitBreakerConfig) error {
	if config.MaxFailures <= 0 {
		return fmt.Errorf("maxFailures必须大于0")
	}
	if err := validateFloatInRange("failureRateThreshold", config.FailureRateThreshold, 0.0, 1.0); err != nil {
		return err
	}
	if len(config.SuccessStatusCodes) == 0 {
		return fmt.Errorf("successStatusCodes不能为空")
	}

	// Integer settings, checked in declaration order.
	intChecks := []struct {
		name          string
		value, lo, hi int
	}{
		{"requestTimeout", config.RequestTimeout, 0, 300000},
		{"distributedTTL", config.DistributedTTL, 0, 3600},
		{"statIntervalMs", config.StatIntervalMs, 100, 60000},
		{"minRequestAmount", config.MinRequestAmount, 1, 10000},
		{"halfOpenMaxRequests", config.HalfOpenMaxRequests, 1, 100},
	}
	for _, c := range intChecks {
		if err := validateInRange(c.name, c.value, c.lo, c.hi); err != nil {
			return err
		}
	}

	if err := validateFloatInRange("halfOpenSuccessThreshold", config.HalfOpenSuccessThreshold, 0.0, 1.0); err != nil {
		return err
	}

	if !config.EnableAdaptiveThreshold {
		return nil
	}
	if err := validateFloatInRange("adaptiveMinThreshold", config.AdaptiveMinThreshold, 0.0, 1.0); err != nil {
		return err
	}
	if err := validateFloatInRange("adaptiveMaxThreshold", config.AdaptiveMaxThreshold, 0.0, 1.0); err != nil {
		return err
	}
	if config.AdaptiveMinThreshold >= config.AdaptiveMaxThreshold {
		return fmt.Errorf("adaptiveMinThreshold必须小于adaptiveMaxThreshold")
	}
	return nil
}
// initServiceCircuitBreaker validates config, installs the Sentinel circuit
// breaking rule for the service and registers a fresh CircuitBreakerInfo
// (closed state, reset metrics, new warm-up deadline) under serviceName.
//
// Strategy selection:
//   - EnableSlidingWindow: slow-request-ratio rule; the threshold is
//     FailureRateThreshold, or the midpoint of the adaptive min/max when
//     EnableAdaptiveThreshold is set.
//   - otherwise: error-count rule with MaxFailures as the threshold.
//
// The rule is installed with LoadRulesOfResource so that only this service's
// resource is touched. circuitbreaker.LoadRules replaces the entire global
// rule set, which made every service initialisation discard the rules of all
// previously initialised services. The resource's rules are cleared first so
// the rule is re-installed even when it is unchanged.
//
// It returns a validation error, or an error wrapping the Sentinel failure.
func initServiceCircuitBreaker(serviceName string, config *CircuitBreakerConfig) error {
	if err := validateCircuitBreakerConfig(config); err != nil {
		return err
	}

	resourceName := "service:" + serviceName
	threshold := config.FailureRateThreshold
	if config.EnableAdaptiveThreshold {
		threshold = (config.AdaptiveMinThreshold + config.AdaptiveMaxThreshold) / 2
	}

	// Fields shared by both strategies.
	rule := &circuitbreaker.Rule{
		Resource:         resourceName,
		RetryTimeoutMs:   uint32(config.TimeoutParsed.Milliseconds()),
		MinRequestAmount: uint64(config.MinRequestAmount),
		StatIntervalMs:   uint32(config.StatIntervalMs),
	}
	strategy := "error_count"
	if config.EnableSlidingWindow {
		strategy = "slow_ratio"
		rule.Strategy = circuitbreaker.SlowRequestRatio
		rule.StatSlidingWindowBucketCount = 10
		rule.MaxAllowedRtMs = uint64(config.SlowRequestThresholdParsed.Milliseconds())
		rule.Threshold = threshold
	} else {
		rule.Strategy = circuitbreaker.ErrorCount
		rule.Threshold = float64(config.MaxFailures)
	}

	if _, err := circuitbreaker.LoadRulesOfResource(resourceName, []*circuitbreaker.Rule{}); err != nil {
		return fmt.Errorf("清空熔断规则失败: %w", err)
	}
	// Scoped to this resource: rules of other services are left alone.
	if _, err := circuitbreaker.LoadRulesOfResource(resourceName, []*circuitbreaker.Rule{rule}); err != nil {
		return fmt.Errorf("加载熔断规则失败: %w", err)
	}

	successCodeMap := make(map[int]bool, len(config.SuccessStatusCodes))
	for _, code := range config.SuccessStatusCodes {
		successCodeMap[code] = true
	}

	cbInfo := &CircuitBreakerInfo{
		ResourceName:      resourceName,
		Config:            config,
		Metrics:           newCircuitBreakerMetrics(),
		SuccessCodeMap:    successCodeMap,
		CIDRNetMasks:      config.CIDRNetMasks,
		AdaptiveThreshold: threshold,
		WarmupEndTime:     time.Now().Add(config.WarmupDurationParsed).Unix(),
	}
	cbInfo.init()
	circuitBreakers.Store(serviceName, cbInfo)

	g.Log().Infof(context.Background(), "服务 %s 熔断器初始化成功: resource=%s, strategy=%s, timeout=%v, threshold=%.2f",
		serviceName, resourceName, strategy, config.TimeoutParsed, rule.Threshold)
	return nil
}
// CircuitBreakerMiddleware is the ghttp middleware that applies circuit
// breaking and fallback to requests whose first URL path segment names a
// configured service.
//
// Flow, in order:
//  1. Requests for unknown or disabled services pass straight through.
//  2. During the breaker's warm-up period requests pass through (only
//     TotalRequests is counted).
//  3. An optional per-request context timeout is installed.
//  4. If the distributed (Redis) state says "open", the request is rejected.
//  5. In the half-open state at most HalfOpenMaxRequests probes are admitted;
//     exceeding the limit re-opens the breaker and rejects the request.
//  6. The request must obtain a Sentinel entry; a block re-opens the breaker
//     and rejects the request.
//  7. After the downstream handlers ran, latency, window and success/failure
//     metrics are updated and the local state machine is advanced.
//
// Rejections are answered through sendFallbackResponse.
//
// NOTE(review): no transition into StateHalfOpen is visible in this part of
// the file — confirm where the local state becomes half-open.
func CircuitBreakerMiddleware(r *ghttp.Request) {
	startTime := time.Now()
	ctx := r.GetCtx()
	serviceName := extractServiceName(r.URL.Path)
	if serviceName == "" {
		r.Middleware.Next()
		return
	}
	cbInfo, config := getCircuitBreakerInfoAndConfig(serviceName)
	if cbInfo == nil || config == nil || !config.Enabled {
		r.Middleware.Next()
		return
	}
	cbInfo.Metrics.TotalRequests.Add(1)
	// Warm-up check: bypass the breaker until WarmupEndTime has passed.
	if time.Now().Unix() < cbInfo.WarmupEndTime {
		r.Middleware.Next()
		return
	}
	resourceName := cbInfo.ResourceName
	if config.RequestTimeout > 0 {
		var ctxCancel context.CancelFunc
		ctx, ctxCancel = context.WithTimeout(ctx, time.Duration(config.RequestTimeout)*time.Millisecond)
		r.SetCtx(ctx)
		defer ctxCancel()
	}
	// Distributed breaker check.
	if config.DistributedTTL > 0 && isCircuitBreakerOpenInDistributed(ctx, resourceName) {
		cbInfo.Metrics.BlockRequests.Add(1)
		g.Log().Warningf(ctx, "分布式熔断触发: %s", resourceName)
		sendFallbackResponse(r, serviceName, config, "distributed")
		return
	}
	// Half-open handling, using atomic counters.
	currentState := cbInfo.getState()
	if currentState == StateHalfOpen {
		// Atomically claim a half-open probe slot.
		halfOpenRequests := cbInfo.Metrics.HalfOpenRequests.Add(1)
		// Over the probe limit: give the slot back and re-open the breaker.
		if halfOpenRequests > int64(config.HalfOpenMaxRequests) {
			// Roll back the increment.
			cbInfo.Metrics.HalfOpenRequests.Add(-1)
			cbInfo.Metrics.BlockRequests.Add(1)
			// Switch to open; log and sync only if this call changed the state.
			// NOTE(review): unlike the blocked path below, this transition does
			// not call notifyStateChange — confirm whether that is intended.
			oldState := cbInfo.setState(StateOpen)
			if oldState != StateOpen {
				g.Log().Warningf(ctx, "半开状态试探请求超限,恢复熔断: %s", resourceName)
				if config.DistributedTTL > 0 {
					syncCircuitBreakerStateToDistributed(ctx, resourceName, "open", config.DistributedTTL)
				}
			}
			sendFallbackResponse(r, serviceName, config, "halfopen_limit")
			return
		}
	}
	entry, blockError := api.Entry(resourceName)
	if blockError != nil {
		if entry != nil {
			entry.Exit()
		}
		cbInfo.Metrics.BlockRequests.Add(1)
		oldState := cbInfo.setStateWithMetrics(StateOpen, true)
		if oldState != StateOpen {
			notifyStateChange(serviceName, oldState, StateOpen)
		}
		if config.DistributedTTL > 0 {
			syncCircuitBreakerStateToDistributed(ctx, resourceName, "open", config.DistributedTTL)
		}
		sendFallbackResponse(r, serviceName, config, "blocked")
		return
	}
	if entry != nil {
		defer entry.Exit()
	}
	r.Middleware.Next()
	statusCode := r.Response.Status
	// Statuses outside 100-599 are not recorded at all.
	if statusCode < 100 || statusCode > 599 {
		return
	}
	duration := time.Since(startTime)
	// Record response time statistics.
	updateResponseTimeStats(cbInfo, duration, config)
	isSuccess := isSuccessStatusCode(cbInfo, statusCode)
	// Update the window statistics.
	cbInfo.updateWindowStats(isSuccess, ctx)
	if !isSuccess {
		cbInfo.Metrics.FailureRequests.Add(1)
		if entry != nil {
			api.TraceError(entry, fmt.Errorf("request failed with status: %d", statusCode))
		}
		g.Log().Debugf(ctx, "服务 %s 请求失败: status=%d, duration=%v", serviceName, statusCode, duration)
		// Re-read the state so a stale value is not used.
		currentState := cbInfo.getState()
		if currentState == StateHalfOpen {
			cbInfo.Metrics.HalfOpenFailed.Add(1)
			oldState := cbInfo.setStateWithMetrics(StateOpen, true)
			if oldState == StateHalfOpen {
				g.Log().Warningf(ctx, "半开状态请求失败,恢复熔断: %s", resourceName)
				if config.DistributedTTL > 0 {
					syncCircuitBreakerStateToDistributed(ctx, resourceName, "open", config.DistributedTTL)
				}
			}
		}
	} else {
		cbInfo.Metrics.PassRequests.Add(1)
		// Re-read the current state.
		currentState := cbInfo.getState()
		if currentState == StateHalfOpen {
			// Atomically count the successful probe.
			halfOpenPassed := cbInfo.Metrics.HalfOpenPassed.Add(1)
			totalRequests := cbInfo.Metrics.HalfOpenRequests.Load()
			// Compute the success rate, guarding against a zero denominator.
			if totalRequests > 0 {
				successRate := float64(halfOpenPassed) / float64(totalRequests)
				// Close the breaker once the success-rate threshold is reached.
				if successRate >= config.HalfOpenSuccessThreshold {
					// The swap's return value tells whether this goroutine
					// performed the half-open -> closed transition.
					oldState := cbInfo.setStateWithMetrics(StateClosed, true)
					if oldState == StateHalfOpen {
						// Reset the half-open statistics.
						cbInfo.Metrics.HalfOpenPassed.Store(0)
						cbInfo.Metrics.HalfOpenRequests.Store(0)
						cbInfo.Metrics.HalfOpenFailed.Store(0)
						g.Log().Infof(ctx, "半开状态成功,恢复关闭: %s, successRate=%.2f, total=%d, passed=%d",
							resourceName, successRate, totalRequests, halfOpenPassed)
						// Sync the distributed state.
						if config.DistributedTTL > 0 {
							syncCircuitBreakerStateToDistributed(ctx, resourceName, "closed", config.DistributedTTL)
						}
					}
				}
			}
		} else if currentState != StateClosed {
			// Neither closed nor half-open (i.e. open) but the request
			// succeeded: move the local state back to closed.
			oldState := cbInfo.setStateWithMetrics(StateClosed, true)
			if oldState != StateClosed {
				notifyStateChange(serviceName, oldState, StateClosed)
			}
		}
	}
}
// sendFallbackResponse logs the rejection and answers the request with HTTP
// 503. The body is the configured FallbackMessage when fallback is enabled
// and the message is non-empty; otherwise it is a built-in message chosen by
// reason ("blocked", "distributed", or anything else).
func sendFallbackResponse(r *ghttp.Request, serviceName string, config *CircuitBreakerConfig, reason string) {
	g.Log().Warningf(r.GetCtx(), "熔断器降级: service=%s, reason=%s, clientIP=%s", serviceName, reason, r.GetClientIp())

	var message string
	switch {
	case config.EnableFallback && config.FallbackMessage != "":
		message = config.FallbackMessage
	case reason == "blocked":
		message = fmt.Sprintf("服务 '%s' 熔断保护中,请稍后再试", serviceName)
	case reason == "distributed":
		message = fmt.Sprintf("服务 '%s' 分布式熔断中", serviceName)
	default:
		message = fmt.Sprintf("服务 '%s' 暂时不可用,请稍后再试", serviceName)
	}
	r.Response.WriteStatusExit(503, message)
}
// isSuccessStatusCode reports whether statusCode counts as a success for this
// breaker. Codes outside 100-599 never do. When the breaker has a non-empty
// SuccessCodeMap only codes in it succeed; otherwise any 2xx code does.
func isSuccessStatusCode(cbInfo *CircuitBreakerInfo, statusCode int) bool {
	switch {
	case statusCode < 100, statusCode > 599:
		return false
	case len(cbInfo.SuccessCodeMap) > 0:
		return cbInfo.SuccessCodeMap[statusCode]
	default:
		return 200 <= statusCode && statusCode < 300
	}
}
// extractServiceName returns the first segment of the URL path
// (percent-decoded when it contains '%') if a circuit breaker config is
// registered under that name, and "" otherwise.
func extractServiceName(path string) string {
	trimmed := strings.Trim(path, "/")
	if trimmed == "" {
		return ""
	}
	// Split always yields at least one element for a non-empty input.
	segment := strings.Split(trimmed, "/")[0]
	if strings.Contains(segment, "%") {
		if decoded, err := pathUnescape(segment); err == nil {
			segment = decoded
		}
	}
	if _, known := circuitBreakerConfigs.Load(segment); !known {
		return ""
	}
	return segment
}
// pathUnescape percent-decodes a path segment leniently: every "%XY" with two
// hexadecimal digits becomes the corresponding byte, while a '%' that is not
// followed by two hex digits is copied through unchanged. The returned error
// is always nil.
func pathUnescape(s string) (string, error) {
	var out strings.Builder
	out.Grow(len(s))
	for pos := 0; pos < len(s); pos++ {
		ch := s[pos]
		// Ordinary byte, or a '%' too close to the end to carry two digits.
		if ch != '%' || pos+2 >= len(s) {
			out.WriteByte(ch)
			continue
		}
		hi, lo := hexDigit(s[pos+1]), hexDigit(s[pos+2])
		if hi == 0xFF || lo == 0xFF {
			out.WriteByte(ch)
			continue
		}
		out.WriteByte(hi<<4 | lo)
		pos += 2 // skip the two digits just consumed
	}
	return out.String(), nil
}

// hexDigit returns the value (0-15) of a hexadecimal digit, or 0xFF when c is
// not a hexadecimal digit.
func hexDigit(c byte) byte {
	if c >= '0' && c <= '9' {
		return c - '0'
	}
	if c >= 'a' && c <= 'f' {
		return c - 'a' + 10
	}
	if c >= 'A' && c <= 'F' {
		return c - 'A' + 10
	}
	return 0xFF
}
// updateAdminIPsCache rebuilds the admin whitelist cache as the union of the
// AdminIPs and CIDRNetMasks of every registered config, then swaps the new
// values in under their respective locks.
func updateAdminIPsCache() {
	ips := map[string]bool{}
	cidrs := make([]*net.IPNet, 0)

	circuitBreakerConfigs.Range(func(_, stored interface{}) bool {
		if cfg, isConfig := stored.(*CircuitBreakerConfig); isConfig {
			for _, ip := range cfg.AdminIPs {
				ips[ip] = true
			}
			cidrs = append(cidrs, cfg.CIDRNetMasks...)
		}
		return true
	})

	allowedAdminIPsMutex.Lock()
	allowedAdminIPsMap = ips
	allowedAdminIPsMutex.Unlock()

	allowedAdminCIDRsMutex.Lock()
	allowedAdminCIDRs = cidrs
	allowedAdminCIDRsMutex.Unlock()
}
// filterServiceNames returns the keys of the "circuitBreaker" configuration
// map that denote services, i.e. every key except the reserved global option
// names. The order of the result follows map iteration and is unspecified.
func filterServiceNames(services map[string]interface{}) []string {
	names := make([]string, 0, len(services))
	for key := range services {
		switch key {
		case "services", "enableDistributed", "requestTimeout", "distributedTTL":
			continue // reserved option, not a service
		}
		names = append(names, key)
	}
	return names
}
// isCircuitBreakerOpenInDistributed reports whether the shared Redis state
// key of resourceName currently holds "open". A missing Redis client, a read
// error or a missing key all count as "not open".
func isCircuitBreakerOpenInDistributed(ctx context.Context, resourceName string) bool {
	client := g.Redis()
	if client == nil {
		return false
	}
	value, err := client.Get(ctx, "circuit_breaker:"+resourceName+":state")
	if err != nil || value.IsNil() {
		return false
	}
	return value.String() == "open"
}
// syncCircuitBreakerStateToDistributed publishes state for resourceName to
// Redis with the given TTL (SETEX), doing the write inside the distributed
// lock provided by the common redis package. It is best effort: every failure
// is only logged, and the write is skipped when the lock is not obtained.
func syncCircuitBreakerStateToDistributed(ctx context.Context, resourceName, state string, ttl int) {
	prefix := "circuit_breaker:" + resourceName
	stateKey, lockKey := prefix+":state", prefix+":lock"

	redisClient := g.Redis()
	if redisClient == nil {
		g.Log().Warningf(ctx, "Redis未初始化无法同步分布式熔断状态: %s", resourceName)
		return
	}

	// The write itself; it never reports an error to the lock helper.
	persist := func(ctx context.Context) error {
		if _, err := redisClient.Do(ctx, "SETEX", stateKey, ttl, state); err != nil {
			g.Log().Errorf(ctx, "设置分布式熔断状态失败: %s=%s, error: %v", stateKey, state, err)
			return nil
		}
		g.Log().Debugf(ctx, "分布式熔断状态已同步: %s=%s (TTL: %d)", stateKey, state, ttl)
		return nil
	}

	acquired, err := redis.Lock(ctx, lockKey, 10, persist)
	switch {
	case err != nil:
		g.Log().Errorf(ctx, "获取分布式锁失败: %s, error: %v", lockKey, err)
	case !acquired:
		g.Log().Debugf(ctx, "未获取到分布式锁,跳过状态同步: %s", lockKey)
	}
}
// CircuitBreakerHealthCheckHandler is the admin endpoint that reports breaker
// state and metrics as JSON, one page of services at a time.
//
// Request parameters: "page" (zero-based, negative values become 0) and
// "size" (1-100, anything else becomes 20). Callers whose IP is not
// whitelisted receive a 403 payload. The "summary" counts cover only the
// services on the returned page.
//
// The out-of-range check compares page against the last valid page index
// instead of computing page*size first, so a huge client-supplied page can no
// longer overflow into a negative start index and an out-of-range panic.
func CircuitBreakerHealthCheckHandler(r *ghttp.Request) {
	if !isAdminIP(r) {
		r.Response.WriteJsonExit(ghttp.DefaultHandlerResponse{Code: 403, Message: "权限不足,禁止访问"})
		return
	}
	page := r.Get("page").Int()
	size := r.Get("size").Int()
	if page < 0 {
		page = 0
	}
	if size <= 0 || size > 100 {
		size = 20
	}

	serviceNamesMutex.RLock()
	slice := serviceNamesSlice
	serviceNamesMutex.RUnlock()
	total := len(slice)

	// Equivalent to page*size >= total, but cannot overflow.
	if total == 0 || page > (total-1)/size {
		r.Response.WriteJsonExit(ghttp.DefaultHandlerResponse{Code: 200, Message: "熔断器状态",
			Data: map[string]interface{}{
				"summary":  map[string]interface{}{"totalServices": 0, "openServices": 0, "closedServices": 0, "halfOpenServices": 0},
				"services": map[string]interface{}{}, "page": page, "size": size, "total": total}})
		return
	}
	start := page * size // bounded by total-1 thanks to the check above
	end := start + size
	if end > total {
		end = total
	}

	status := make(map[string]interface{})
	totalServices := 0
	openServices := 0
	halfOpenServices := 0
	for i := start; i < end; i++ {
		serviceName := slice[i]
		cbInfoVal, ok := circuitBreakers.Load(serviceName)
		if !ok {
			continue
		}
		cbInfo, ok := cbInfoVal.(*CircuitBreakerInfo)
		if !ok {
			continue
		}
		totalServices++
		state := cbInfo.getState()
		if state == StateOpen {
			openServices++
		} else if state == StateHalfOpen {
			halfOpenServices++
		}
		// Format the timestamps for display.
		lastResetTimeStr := formatUnixTime(cbInfo.Metrics.LastResetTime.Load())
		lastOpenTimeStr := formatUnixTime(cbInfo.Metrics.LastOpenTime.Load())
		nextRetryTimeStr := formatUnixTime(cbInfo.Metrics.NextRetryTime.Load())
		status[serviceName] = map[string]interface{}{
			"resource":         cbInfo.ResourceName,
			"state":            string(state),
			"lastOpenTime":     lastOpenTimeStr,
			"nextRetryTime":    nextRetryTimeStr,
			"totalRequests":    cbInfo.Metrics.TotalRequests.Load(),
			"passRequests":     cbInfo.Metrics.PassRequests.Load(),
			"blockRequests":    cbInfo.Metrics.BlockRequests.Load(),
			"failureRequests":  cbInfo.Metrics.FailureRequests.Load(),
			"slowRequests":     cbInfo.Metrics.SlowRequests.Load(),
			"openCount":        cbInfo.Metrics.OpenCount.Load(),
			"lastResetTime":    lastResetTimeStr,
			"halfOpenRequests": cbInfo.Metrics.HalfOpenRequests.Load(),
			"halfOpenPassed":   cbInfo.Metrics.HalfOpenPassed.Load(),
		}
	}
	r.Response.WriteJsonExit(ghttp.DefaultHandlerResponse{Code: 200, Message: "熔断器状态",
		Data: map[string]interface{}{
			"summary":  map[string]interface{}{"totalServices": totalServices, "openServices": openServices, "closedServices": totalServices - openServices - halfOpenServices, "halfOpenServices": halfOpenServices},
			"services": status, "page": page, "size": size, "total": total}})
}
// isAdminIP reports whether the request's client IP may use the admin
// endpoints: it must match a whitelisted IP exactly or fall inside a
// whitelisted CIDR. A request without a client IP is always refused.
//
// NOTE(review): when neither an IP nor a CIDR whitelist is configured, every
// client is allowed (fail-open) — confirm this is the intended default.
func isAdminIP(r *ghttp.Request) bool {
	clientIP := r.GetClientIp()
	if clientIP == "" {
		return false
	}

	// Fetch both lists in one call to limit lock traffic.
	ips, cidrs := getAllowedIPsAndCIDRs()
	switch {
	case len(ips) == 0 && len(cidrs) == 0:
		return true // no restriction configured
	case ips[clientIP]:
		return true // exact IP match
	}

	if parsed := net.ParseIP(clientIP); parsed != nil {
		for i := range cidrs {
			if cidrs[i].Contains(parsed) {
				return true
			}
		}
	}

	g.Log().Warningf(r.GetCtx(), "熔断器操作请求被拒绝IP不在白名单中: %s", clientIP)
	return false
}
// batchProcessServices runs processFunc for every known service name. It
// returns the number of successes, the number of failures and a map from
// failed service name to its error message. Failures are logged and do not
// stop the batch.
func batchProcessServices(r *ghttp.Request, processFunc func(serviceName string) error) (int, int, map[string]string) {
	serviceNamesMutex.RLock()
	names := serviceNamesSlice
	serviceNamesMutex.RUnlock()

	var succeeded, failed int
	failures := map[string]string{}
	for _, name := range names {
		err := processFunc(name)
		if err == nil {
			succeeded++
			continue
		}
		g.Log().Errorf(r.GetCtx(), "服务 %s 处理失败: %v", name, err)
		failed++
		failures[name] = err.Error()
	}
	return succeeded, failed, failures
}
// CircuitBreakerResetHandler is the admin endpoint that resets breakers. The
// "service" parameter selects one service; an empty value or "*" resets all
// known services and reports per-service failures. Callers whose IP is not
// whitelisted receive a 403 payload.
func CircuitBreakerResetHandler(r *ghttp.Request) {
	serviceName := r.Get("service").String()
	if !isAdminIP(r) {
		r.Response.WriteJsonExit(ghttp.DefaultHandlerResponse{Code: 403, Message: "权限不足,禁止访问"})
		return
	}

	resetAll := serviceName == "" || serviceName == "*"
	if !resetAll {
		if err := resetSingleService(r, serviceName); err != nil {
			r.Response.WriteJsonExit(ghttp.DefaultHandlerResponse{Code: 500, Message: fmt.Sprintf("重置熔断器失败: %v", err)})
			return
		}
		r.Response.WriteJsonExit(ghttp.DefaultHandlerResponse{Code: 200, Message: fmt.Sprintf("服务 '%s' 的熔断器已重置", serviceName)})
		return
	}

	resetOne := func(name string) error {
		return resetSingleService(r, name)
	}
	successCount, failCount, failures := batchProcessServices(r, resetOne)
	g.Log().Infof(r.GetCtx(), "批量重置熔断器完成: 成功 %d, 失败 %d", successCount, failCount)
	r.Response.WriteJsonExit(ghttp.DefaultHandlerResponse{Code: 200, Message: fmt.Sprintf("批量重置完成: 成功 %d, 失败 %d", successCount, failCount),
		Data: map[string]interface{}{"success": successCount, "failed": failCount, "failures": failures}})
}
// resetSingleService resets the breaker of one service: it clears the
// service's Sentinel rules, re-creates the breaker from its registered
// config, and deletes the shared Redis state key when distributed mode is
// configured. It returns an error only when clearing the rules or
// re-initialising the breaker fails; a failed Redis delete is just logged.
//
// Re-initialisation already installs a fresh CircuitBreakerInfo (closed
// state, reset metrics, new warm-up deadline and reset timestamp), so the
// in-place reset is performed only as a fallback when no usable config is
// registered. This avoids a second, non-atomic write to WarmupEndTime on a
// breaker the middleware may already be reading. Type assertions are checked
// so an unexpected registry value cannot panic the handler.
func resetSingleService(r *ghttp.Request, serviceName string) error {
	ctx := r.GetCtx()
	resourceName := "service:" + serviceName

	if rules := circuitbreaker.GetRulesOfResource(resourceName); len(rules) > 0 {
		if _, err := circuitbreaker.LoadRulesOfResource(resourceName, []*circuitbreaker.Rule{}); err != nil {
			return err
		}
	}

	var config *CircuitBreakerConfig
	if stored, ok := circuitBreakerConfigs.Load(serviceName); ok {
		config, _ = stored.(*CircuitBreakerConfig)
	}

	if config != nil {
		if err := initServiceCircuitBreaker(serviceName, config); err != nil {
			return err
		}
	} else if stored, ok := circuitBreakers.Load(serviceName); ok {
		// Fallback: no config to rebuild from, reset the existing breaker in place.
		if cbInfo, ok := stored.(*CircuitBreakerInfo); ok && cbInfo != nil && cbInfo.Metrics != nil {
			cbInfo.State.Store(stateClosed)
			cbInfo.Metrics.reset()
			if cbInfo.Config != nil {
				cbInfo.WarmupEndTime = time.Now().Add(cbInfo.Config.WarmupDurationParsed).Unix()
			}
			cbInfo.Metrics.LastResetTime.Store(time.Now().Unix())
		}
	}

	if config != nil && config.DistributedTTL > 0 {
		if client := g.Redis(); client != nil {
			if _, err := client.Del(ctx, "circuit_breaker:"+resourceName+":state"); err != nil {
				g.Log().Warningf(ctx, "清除分布式熔断状态失败: %s, error: %v", resourceName, err)
			}
		}
	}

	g.Log().Infof(ctx, "熔断器已手动重置: %s", resourceName)
	return nil
}
// CircuitBreakerReloadHandler is the admin endpoint that reloads breaker
// configuration. The "service" parameter selects one service; an empty value
// or "*" reloads all known services. After a successful single reload, and
// after every batch reload, the admin whitelist cache is rebuilt. Callers
// whose IP is not whitelisted receive a 403 payload.
func CircuitBreakerReloadHandler(r *ghttp.Request) {
	serviceName := r.Get("service").String()
	if !isAdminIP(r) {
		r.Response.WriteJsonExit(ghttp.DefaultHandlerResponse{Code: 403, Message: "权限不足,禁止访问"})
		return
	}

	reloadAll := serviceName == "" || serviceName == "*"
	if !reloadAll {
		if err := ReloadCircuitBreakerConfig(serviceName); err != nil {
			r.Response.WriteJsonExit(ghttp.DefaultHandlerResponse{Code: 500, Message: fmt.Sprintf("重载失败: %v", err)})
			return
		}
		updateAdminIPsCache()
		r.Response.WriteJsonExit(ghttp.DefaultHandlerResponse{Code: 200, Message: fmt.Sprintf("服务 '%s' 的熔断器配置已重载", serviceName)})
		return
	}

	successCount, failCount, failures := batchProcessServices(r, func(name string) error {
		return ReloadCircuitBreakerConfig(name)
	})
	updateAdminIPsCache()
	r.Response.WriteJsonExit(ghttp.DefaultHandlerResponse{Code: 200, Message: fmt.Sprintf("配置重载完成: 成功 %d, 失败 %d", successCount, failCount),
		Data: map[string]interface{}{"success": successCount, "failed": failCount, "failures": failures}})
}
// StateChangeListener is a callback invoked by notifyStateChange with the
// service name and the state the breaker moved from and to.
type StateChangeListener func(serviceName string, fromState, toState CircuitBreakerState)
// RegisterStateChangeListener registers listener under name. Registering
// again with the same name replaces the earlier listener.
func RegisterStateChangeListener(name string, listener StateChangeListener) {
	registry := &stateChangeListeners
	registry.Store(name, listener)
}
// notifyStateChange synchronously calls every registered listener with the
// given transition. Registry entries that are not a StateChangeListener are
// skipped.
func notifyStateChange(serviceName string, fromState, toState CircuitBreakerState) {
	stateChangeListeners.Range(func(_, registered interface{}) bool {
		if notify, isListener := registered.(StateChangeListener); isListener {
			notify(serviceName, fromState, toState)
		}
		return true
	})
}
// registerStateChangeListeners installs the default state-change listener
// exactly once; later calls are no-ops.
//
// The default listener logs every transition: transitions into StateOpen are
// logged at warning level, all others at info level. It previously passed the
// words "Info"/"Warning" to Logger.Print, which only printed them as text and
// did not log at that level.
func registerStateChangeListeners() {
	if _, exists := stateChangeListenersRegistered.LoadOrStore("default", true); exists {
		return
	}
	RegisterStateChangeListener("default", func(serviceName string, fromState, toState CircuitBreakerState) {
		const format = "熔断器状态变化: service=%s, %s -> %s"
		ctx := context.Background()
		if toState == StateOpen {
			g.Log().Warningf(ctx, format, serviceName, fromState, toState)
			return
		}
		g.Log().Infof(ctx, format, serviceName, fromState, toState)
	})
}