From cc67dd2485b3fc427faeafd44e599c349cff6b96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=96=8C?= <259278618@qq.com> Date: Mon, 29 Dec 2025 14:44:32 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E9=99=90=E6=B5=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- consts/redis_key.go | 26 +++- middleware/middleware.go | 18 --- middleware/rate_limiter.go | 300 +++++++++++++++++++++++++++++++++++++ redis/redis.go | 11 +- 4 files changed, 325 insertions(+), 30 deletions(-) create mode 100644 middleware/rate_limiter.go diff --git a/consts/redis_key.go b/consts/redis_key.go index 183aa69..3f85dff 100644 --- a/consts/redis_key.go +++ b/consts/redis_key.go @@ -1,7 +1,23 @@ package consts -const CleanList = "list:tenantId-%v:collection-%s:*" -const CleanCount = "count:tenantId-%v:collection-%s:*" -const List = "list:tenantId-%v:collection-%s:filter:%s:options:%s" -const Count = "count:tenantId-%v:collection-%s:filter:%s" -const One = "one:tenantId-%v:collection-%s:filter:%s" +// Redis 数据缓存 Key 常量 +const ( + CleanList = "list:tenantId-%v:collection-%s:*" // 清理列表Key + CleanCount = "count:tenantId-%v:collection-%s:*" // 清理计数Key + List = "list:tenantId-%v:collection-%s:filter:%s:options:%s" // 列表查询Key + Count = "count:tenantId-%v:collection-%s:filter:%s" // 计数查询Key + One = "one:tenantId-%v:collection-%s:filter:%s" // 单条查询Key +) + +// 限流 Redis Key 常量 +const ( + RateLimitKeyPrefix = "ragflow:ratelimit:" // 限流Key前缀 + RateLimitKeyIP = "ip:%s" // IP限流: ip:192.168.1.1 + RateLimitKeyUser = "user:%s" // 用户限流: user:123 或 user:anon:192.168.1.1 + RateLimitKeyService = "service:%s" // 服务限流: service:customerService + RateLimitKeyGlobal = "global:requests" // 全局限流: global:requests + RateLimitKeyOrder = "order:create:%s" // 订单创建限流: order:create:123 + RateLimitKeyTransfer = "wallet:transfer:%s" // 钱包转账限流: wallet:transfer:123 + RateLimitKeyMessage = "cs:message:%s" // 客服消息限流: cs:message:123 + RateLimitKeyUpload = "oss:upload:%s" // 文件上传限流: oss:upload:123 +) diff --git a/middleware/middleware.go b/middleware/middleware.go index d37ba83..9178ddc 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -1,13 +1,10 @@ package middleware import ( - "context" - "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/net/ghttp" "github.com/gogf/gf/v2/os/gtime" "github.com/gogf/gf/v2/text/gstr" - "golang.org/x/time/rate" ) // Logger 中间件 @@ -24,17 +21,6 @@ func Logger(r *ghttp.Request) { ) } -var rateLimit = g.Cfg().MustGet(context.TODO(), "rate.limit").Int() -var rateBurst = g.Cfg().MustGet(context.TODO(), "rate.burst").Int() -var limiter = rate.NewLimiter(rate.Limit(rateLimit), rateBurst) - -func Limiter(r *ghttp.Request) { - if !limiter.Allow() { - r.Response.WriteStatusExit(429) // Return 429 Too Many Requests - r.ExitAll() - } - r.Middleware.Next() -} func Auth(r *ghttp.Request) { //utils.GetUserInfo(r.GetCtx()) token := r.Header.Get("Authorization") @@ -51,7 +37,3 @@ func Auth(r *ghttp.Request) { r.Middleware.Next() } -func validateToken(token string) bool { - // 实现 token 验证逻辑 - return token == "valid-token" -} diff --git a/middleware/rate_limiter.go b/middleware/rate_limiter.go new file mode 100644 index 0000000..f8c9bc6 --- /dev/null +++ b/middleware/rate_limiter.go @@ -0,0 +1,300 @@ +package middleware + +import ( + "fmt" + "strings" + + "gitee.com/red-future---jilin-g/common/consts" + "gitee.com/red-future---jilin-g/common/redis" + "gitee.com/red-future---jilin-g/common/utils" + "github.com/gogf/gf/v2/frame/g" + "github.com/gogf/gf/v2/net/ghttp" + "github.com/gogf/gf/v2/text/gstr" + "github.com/gogf/gf/v2/util/gconv" +) + +// GlobalLimiter 全局限流中间件(使用Redis分布式控制) +func GlobalLimiter(r *ghttp.Request) { + // 从配置文件读取全局限流参数 + globalLimit := g.Cfg().MustGet(r.GetCtx(), "rate.limit", 800).Int64() + + key := consts.RateLimitKeyGlobal + + // 使用Redis计数器进行全局限流 + count, err := redis.IncrRateLimit(r.GetCtx(), key, 1) // 1秒窗口 + if err != nil { + g.Log().Errorf(r.GetCtx(), "全局限流Redis错误: %v", err) + r.Middleware.Next() + return + } + + if count > globalLimit { + g.Log().Warningf(r.GetCtx(), "全局限流触发: count: %d, limit: %d", count, globalLimit) + r.Response.WriteStatusExit(429, "系统当前繁忙,请稍后再试") + return + } + + r.Middleware.Next() +} + +// IPLimiter IP限流中间件(防DDoS) +func IPLimiter(r *ghttp.Request) { + ip := r.GetClientIp() + key := fmt.Sprintf(consts.RateLimitKeyIP, ip) + + // 从配置文件读取IP限流参数 + ipLimit := g.Cfg().MustGet(r.GetCtx(), "rate.ip.limit", 100).Int64() + + // 使用Redis计数器 + count, err := redis.IncrRateLimit(r.GetCtx(), key, 1) // 1秒窗口 + if err != nil { + g.Log().Errorf(r.GetCtx(), "IP限流Redis错误: %v", err) + r.Middleware.Next() + return + } + + if count > ipLimit { + g.Log().Warningf(r.GetCtx(), "IP限流触发: %s, count: %d, limit: %d", ip, count, ipLimit) + r.Response.WriteStatusExit(429, "请求过于频繁,请稍后再试") + return + } + + r.Middleware.Next() +} + +// UserLimiter 用户维度限流中间件(防止单用户滥用) +func UserLimiter(r *ghttp.Request) { + // 从JWT获取用户ID(如果已登录) + var userId string + var isAuth bool = false + + if token := r.Header.Get("Authorization"); token != "" && gstr.HasPrefix(token, "Bearer ") { + // 这里应该解析JWT获取用户ID,简化示例中直接使用token + tokenStr := gstr.SubStrFrom(token, "7") + if tokenStr != "" && validateToken(tokenStr) { + userId = tokenStr + isAuth = true + } + } + + // 如果没有userId,使用IP作为标识 + if userId == "" { + userId = "anon:" + r.GetClientIp() + } + + // 从配置文件读取用户限流参数 + var userLimit int64 + if isAuth { + userLimit = g.Cfg().MustGet(r.GetCtx(), "rate.user.authenticated.limit", 50).Int64() + } else { + userLimit = g.Cfg().MustGet(r.GetCtx(), "rate.user.anonymous.limit", 20).Int64() + } + + key := fmt.Sprintf(consts.RateLimitKeyUser, userId) + count, err := redis.IncrRateLimit(r.GetCtx(), key, 1) + if err != nil { + g.Log().Errorf(r.GetCtx(), "用户限流Redis错误: %v", err) + r.Middleware.Next() + return + } + + if count > userLimit { + userType := "已登录" + if !isAuth { + userType = "未登录" + } + g.Log().Warningf(r.GetCtx(), "用户限流触发: %s, count: %d, limit: %d, type: %s", userId, count, userLimit, userType) + r.Response.WriteStatusExit(429, "您的请求过于频繁,请稍后再试") + return + } + + r.Middleware.Next() +} + +// ServiceLimiter 服务维度限流中间件(保护微服务) +func ServiceLimiter(r *ghttp.Request) { + // 从URL路径提取服务名: /customerService/xxx -> customerService + pathParts := strings.Split(strings.Trim(r.URL.Path, "/"), "/") + if len(pathParts) == 0 { + r.Middleware.Next() + return + } + + serverName := pathParts[0] + + // 从配置文件读取服务限流参数 + serviceLimitKey := fmt.Sprintf("rate.services.%s.limit", serverName) + limit := g.Cfg().MustGet(r.GetCtx(), serviceLimitKey, 0).Int64() + + // 如果配置为0,说明该服务没有限流配置,跳过限流 + if limit == 0 { + r.Middleware.Next() + return + } + + key := fmt.Sprintf(consts.RateLimitKeyService, serverName) + count, err := redis.IncrRateLimit(r.GetCtx(), key, 1) + if err != nil { + g.Log().Errorf(r.GetCtx(), "服务限流Redis错误: %v", err) + r.Middleware.Next() + return + } + + if count > limit { + g.Log().Warningf(r.GetCtx(), "服务限流触发: %s, count: %d, limit: %d", serverName, count, limit) + r.Response.WriteStatusExit(429, fmt.Sprintf("服务 '%s' 当前繁忙,请稍后再试", serverName)) + return + } + + r.Middleware.Next() +} + +// OrderCreateLimiter 订单创建限流中间件 +// 限制: 每个用户每分钟最多创建10个订单 +func OrderCreateLimiter(r *ghttp.Request) { + userId := getUserIdFromContext(r) // 从context获取用户ID + if userId == "" { + // 如果无法获取用户信息,跳过限流检查 + r.Middleware.Next() + return + } + + key := fmt.Sprintf(consts.RateLimitKeyOrder, userId) + + // 限制: 每个用户每分钟最多创建10个订单 + count, err := redis.IncrRateLimit(r.GetCtx(), key, 60) // 60秒窗口 + if err != nil { + g.Log().Errorf(r.GetCtx(), "订单创建限流Redis错误: %v", err) + r.Middleware.Next() + return + } + + if count > 10 { + g.Log().Warningf(r.GetCtx(), "订单创建限流触发: %s, count: %d", userId, count) + r.Response.WriteJsonExit(ghttp.DefaultHandlerResponse{ + Code: 429, + Message: "下单过于频繁,请稍后再试", + }) + return + } + + r.Middleware.Next() +} + +// WalletTransferLimiter 钱包转账限流中间件 +// 限制: 每个用户每分钟最多转账5次 +func WalletTransferLimiter(r *ghttp.Request) { + userId := getUserIdFromContext(r) // 从context获取用户ID + if userId == "" { + r.Middleware.Next() + return + } + + key := fmt.Sprintf(consts.RateLimitKeyTransfer, userId) + + // 限制: 每个用户每分钟最多转账5次 + count, err := redis.IncrRateLimit(r.GetCtx(), key, 60) // 60秒窗口 + if err != nil { + g.Log().Errorf(r.GetCtx(), "钱包转账限流Redis错误: %v", err) + r.Middleware.Next() + return + } + + if count > 5 { + g.Log().Warningf(r.GetCtx(), "钱包转账限流触发: %s, count: %d", userId, count) + r.Response.WriteJsonExit(ghttp.DefaultHandlerResponse{ + Code: 429, + Message: "转账操作过于频繁,请稍后再试", + }) + return + } + + r.Middleware.Next() +} + +// CSMessageLimiter 客服消息限流中间件 +// 限制: 每个用户每分钟最多发送30条消息 +func CSMessageLimiter(r *ghttp.Request) { + userId := getUserIdFromContext(r) // 从context获取用户ID + if userId == "" { + r.Middleware.Next() + return + } + + key := fmt.Sprintf(consts.RateLimitKeyMessage, userId) + + // 限制: 每个用户每分钟最多发送30条消息 + count, err := redis.IncrRateLimit(r.GetCtx(), key, 60) // 60秒窗口 + if err != nil { + g.Log().Errorf(r.GetCtx(), "客服消息限流Redis错误: %v", err) + r.Middleware.Next() + return + } + + if count > 30 { + g.Log().Warningf(r.GetCtx(), "客服消息限流触发: %s, count: %d", userId, count) + r.Response.WriteJsonExit(ghttp.DefaultHandlerResponse{ + Code: 429, + Message: "消息发送过于频繁,请稍后再试", + }) + return + } + + r.Middleware.Next() +} + +// OSSUploadLimiter 文件上传限流中间件 +// 限制: 每个用户每分钟最多上传10个文件 +func OSSUploadLimiter(r *ghttp.Request) { + userId := getUserIdFromContext(r) // 从context获取用户ID + if userId == "" { + r.Middleware.Next() + return + } + + key := fmt.Sprintf(consts.RateLimitKeyUpload, userId) + + // 限制: 每个用户每分钟最多上传10个文件 + count, err := redis.IncrRateLimit(r.GetCtx(), key, 60) // 60秒窗口 + if err != nil { + g.Log().Errorf(r.GetCtx(), "文件上传限流Redis错误: %v", err) + r.Middleware.Next() + return + } + + if count > 10 { + g.Log().Warningf(r.GetCtx(), "文件上传限流触发: %s, count: %d", userId, count) + r.Response.WriteJsonExit(ghttp.DefaultHandlerResponse{ + Code: 429, + Message: "文件上传过于频繁,请稍后再试", + }) + return + } + + r.Middleware.Next() +} + +// getUserIdFromContext 从请求上下文中获取用户ID +// 使用项目中已有的utils.GetUserInfo方法 +func getUserIdFromContext(r *ghttp.Request) string { + // 使用项目中已有的utils.GetUserInfo方法获取用户信息 + user, err := utils.GetUserInfo(r.GetCtx()) + if err != nil { + // 如果获取用户信息失败,返回空字符串 + return "" + } + + // 在这个项目中,UserName就是用来标识用户的ID + // 转换为字符串类型 + if user.UserName != nil { + return gconv.String(user.UserName) + } + + return "" +} + +// validateToken 验证token有效性 +func validateToken(token string) bool { + // 实现 token 验证逻辑 + return token == "valid-token" +} diff --git a/redis/redis.go b/redis/redis.go index 552d273..aa4e947 100644 --- a/redis/redis.go +++ b/redis/redis.go @@ -8,6 +8,7 @@ import ( "sync" "time" + "gitee.com/red-future---jilin-g/common/consts" "github.com/gogf/gf/v2/database/gredis" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/os/glog" @@ -478,15 +479,11 @@ func IsUserActive(ctx context.Context, userId string, seconds int64) (bool, erro // ============== 限流相关 ============== -const ( - // RateLimitKeyPrefix 限流计数器 Key 前缀 - RateLimitKeyPrefix = "ragflow:ratelimit:" -) - // IncrRateLimit 增加限流计数器,返回当前计数 +// key: 限流key(需要包含完整路径,如 "ip:192.168.1.1") // windowSeconds: 时间窗口(秒) func IncrRateLimit(ctx context.Context, key string, windowSeconds int64) (count int64, err error) { - fullKey := RateLimitKeyPrefix + key + fullKey := consts.RateLimitKeyPrefix + key result, err := redisClient.Do(ctx, "INCR", fullKey) if err != nil { return @@ -502,7 +499,7 @@ func IncrRateLimit(ctx context.Context, key string, windowSeconds int64) (count // GetRateLimit 获取当前限流计数 func GetRateLimit(ctx context.Context, key string) (count int64, err error) { - fullKey := RateLimitKeyPrefix + key + fullKey := consts.RateLimitKeyPrefix + key result, err := redisClient.Get(ctx, fullKey) if err != nil { return