package gfdb import ( "context" "database/sql" "encoding/json" "fmt" "regexp" "strings" "time" "gitea.com/red-future/common/beans" "gitea.com/red-future/common/utils" "github.com/bwmarrin/snowflake" "github.com/gogf/gf/v2/crypto/gmd5" "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/database/gredis" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/os/gcache" "github.com/gogf/gf/v2/os/glog" "github.com/gogf/gf/v2/text/gstr" "github.com/gogf/gf/v2/util/gconv" "go.opentelemetry.io/otel/trace" ) // ==================== 缓存管理器(单例) ==================== var ( localCache *gcache.Cache ) // getLocalCache 获取本地缓存实例 func getLocalCache() *gcache.Cache { if localCache == nil { localCache = gcache.New() } return localCache } // getFromCache 从缓存获取数据(本地缓存 -> Redis) func getFromCache(ctx context.Context, key string) ([]byte, bool) { // 1. 先查本地缓存 if val, err := getLocalCache().Get(ctx, key); err == nil && val != nil { if data := val.Bytes(); len(data) > 0 { return data, true } } // 2. 再查Redis缓存 if g.Redis() != nil { result, err := g.Redis().Get(ctx, key) if err == nil && !result.IsEmpty() { data := result.Bytes() // 写入本地缓存 err = getLocalCache().Set(ctx, key, data, time.Duration(g.Cfg().MustGet(ctx, "cache.localTTL", 60).Int64())*time.Second) if err != nil { return nil, false } return data, true } } return nil, false } // setToCache 写入缓存(本地缓存 + Redis) func setToCache(ctx context.Context, key string, data []byte) (err error) { if len(data) == 0 { return } // 1. 写入本地缓存 if err = getLocalCache().Set(ctx, key, data, time.Duration(g.Cfg().MustGet(ctx, "cache.localTTL", 60).Int64())*time.Second); err != nil { return } // 2. 写入Redis缓存 if g.Redis() != nil { _, err = g.Redis().Set(ctx, key, data, gredis.SetOption{ TTLOption: gredis.TTLOption{ EX: gconv.PtrInt64(g.Cfg().MustGet(ctx, "cache.redisTTL", 300)), }, }) if err != nil { return } } return } // deleteCacheByPattern 根据模式删除缓存 func deleteCacheByPattern(ctx context.Context, pattern string) (err error) { // 1. 删除匹配模式的本地缓存 localCache := getLocalCache() keys := localCache.MustKeyStrings(ctx) if len(keys) > 0 { for _, key := range keys { if matchPattern(key, pattern) { _, err = localCache.Remove(ctx, key) if err != nil { return err } } } } // 2. 删除Redis缓存(使用SCAN+DEL) if g.Redis() != nil { keys, err := g.Redis().Keys(ctx, pattern) if err != nil { return err } for _, key := range keys { _, err = g.Redis().Del(ctx, key) if err != nil { return err } } } return nil } // matchPattern 检查 key 是否匹配 Redis SCAN 的 MATCH 模式(支持 * 通配符) func matchPattern(key string, pattern string) bool { // 将 Redis 的 MATCH 模式转换为正则表达式 // 转义正则特殊字符(除了 *) regexPattern := regexp.QuoteMeta(pattern) // 将转义后的 \* 替换回 .* regexPattern = strings.ReplaceAll(regexPattern, `\*`, ".*") // 添加开始和结束锚点 regexPattern = "^" + regexPattern + "$" matched, _ := regexp.MatchString(regexPattern, key) return matched } // ==================== 统一Hook入口 ==================== // CatchSQLHook 返回统一的 HookHandler(包含租户自动赋值和缓存) // 使用示例: // // // 基础使用(自动租户赋值,无缓存) // g.DB().Model("user").Hook(base.CatchSQLHook()).Ctx(ctx).Insert(data) // // // 启用缓存(用户无感知,自动处理缓存key) // ctx = base.WithCacheEnabled(ctx, "asset") // Asset.CtxWithCache(ctx).Where("id", 123).Scan(&result) func catchSQLHook() gdb.HookHandler { return gdb.HookHandler{ Insert: insertHook, Update: updateHook, Delete: deleteHook, Select: selectHook, } } // ==================== Insert钩子 ==================== func insertHook(ctx context.Context, in *gdb.HookInsertInput) (result sql.Result, err error) { userInfo, err := utils.GetUserInfo(ctx) if err != nil { return nil, err } node, err := snowflake.NewNode(g.Cfg().MustGet(ctx, "server.workerId").Int64()) if err != nil { return nil, err } for i := range in.Data { if _, ok := in.Data[i]["id"]; ok { in.Data[i]["id"] = node.Generate().Int64() } if _, ok := in.Data[i]["tenant_id"]; ok { if !g.IsEmpty(userInfo.TenantId) { in.Data[i]["tenant_id"] = userInfo.TenantId } else { return nil, fmt.Errorf("tenantId cannot be empty") } } if _, ok := in.Data[i]["creator"]; ok { if !g.IsEmpty(userInfo.UserName) { in.Data[i]["creator"] = userInfo.UserName } else { return nil, fmt.Errorf("user info cannot be empty") } } if _, ok := in.Data[i]["updater"]; ok { if !g.IsEmpty(userInfo.UserName) { in.Data[i]["updater"] = userInfo.UserName } else { return nil, fmt.Errorf("user info cannot be empty") } } } // 2. 执行插入 result, err = in.Next(ctx) if err != nil { return nil, err } // 3. 清除相关缓存 if userInfo != nil && userInfo.TenantId != 0 { if err = deleteCacheByPattern(ctx, getCacheKey(userInfo.TenantId, in.Table, true)); err != nil { return nil, err } } return result, nil } // ==================== Update钩子 ==================== func updateHook(ctx context.Context, in *gdb.HookUpdateInput) (result sql.Result, err error) { // 1. 自动赋值修改人 userInfo, _ := utils.GetUserInfo(ctx) switch data := in.Data.(type) { case gdb.Map: if _, ok := data["updater"]; ok { if !g.IsEmpty(userInfo.UserName) { data["updater"] = userInfo.UserName } else { return nil, fmt.Errorf("user info cannot be empty") } } case gdb.List: for i := range data { if !g.IsEmpty(userInfo.UserName) { if _, ok := data[i]["updater"]; ok { if !g.IsEmpty(userInfo.UserName) { data[i]["updater"] = userInfo.UserName } else { return nil, fmt.Errorf("user info cannot be empty") } } } } } // 2. 执行更新 result, err = in.Next(ctx) if err != nil { return nil, err } // 3. 清除相关缓存 if userInfo != nil && userInfo.TenantId != 0 { if err = deleteCacheByPattern(ctx, getCacheKey(userInfo.TenantId, in.Table, true)); err != nil { return nil, err } } return result, nil } // ==================== Delete钩子 ==================== func deleteHook(ctx context.Context, in *gdb.HookDeleteInput) (result sql.Result, err error) { // 1. 执行删除 result, err = in.Next(ctx) if err != nil { return nil, err } // 2. 清除相关缓存 userInfo, _ := utils.GetUserInfo(ctx) if userInfo != nil && userInfo.TenantId != 0 { if err = deleteCacheByPattern(ctx, getCacheKey(userInfo.TenantId, in.Table, true)); err != nil { return nil, err } } return result, nil } // ==================== Select钩子(缓存读取) ==================== func selectHook(ctx context.Context, in *gdb.HookSelectInput) (result gdb.Result, err error) { var tenantId uint64 // ===================== 最终版:安全追加租户ID ===================== tenantEnabled, err := gcache.Get(ctx, getTraceID(ctx, noTenantIdKeyPrefix)) if err != nil { return } if !gconv.Bool(tenantEnabled) { user, err := utils.GetUserInfo(ctx) if err != nil { return nil, err } tenantId = user.TenantId // 【关键修复】找到 SQL 中第一个出现的 ORDER BY / GROUP BY / LIMIT 等关键字位置 sql := in.Sql insertPos := len(sql) keywords := []string{" ORDER BY ", " GROUP BY ", " HAVING ", " LIMIT ", " FOR UPDATE"} for _, kw := range keywords { if idx := gstr.PosI(sql, kw); idx != -1 { insertPos = idx break } } // 【正确拼接】把条件插入到关键字之前,而不是直接拼在最后 condition := " " + beans.DefSQLBaseCol.TenantId + " = ?" if gstr.Contains(gstr.ToUpper(sql), " WHERE ") { // 有 WHERE → 加 AND in.Sql = sql[:insertPos] + " AND" + condition + sql[insertPos:] } else { // 无 WHERE → 加 WHERE in.Sql = sql[:insertPos] + " WHERE" + condition + sql[insertPos:] } in.Args = append(in.Args, tenantId) } // ================================================================== cacheEnabled, err := gcache.Get(ctx, getTraceID(ctx, cacheKeyPrefix)) if err != nil { return } // 未启用缓存,直接执行查询 if !gconv.Bool(cacheEnabled) { return in.Next(ctx) } // 从 SQL 字符串中提取 WHERE 条件部分 whereCondition := "" // 查找 WHERE 关键字(不区分大小写) whereIndex := gstr.PosI(in.Sql, " WHERE ") if whereIndex != -1 { // 提取 WHERE 之后的内容 whereCondition = in.Sql[whereIndex+7:] // 移除 ORDER BY, GROUP BY, HAVING, LIMIT 等后续子句 for _, keyword := range []string{" ORDER BY ", " GROUP BY ", " HAVING ", " LIMIT ", " FOR UPDATE"} { if idx := gstr.PosI(whereCondition, keyword); idx != -1 { whereCondition = whereCondition[:idx] } } } encrypt, err := gmd5.Encrypt(fmt.Sprintf("%s:%s", whereCondition, in.Args)) if err != nil { return nil, err } // 构建缓存key:sql:tenantId:table:where条件:args cacheKey := fmt.Sprintf("%s:%s:%s", getCacheKey(tenantId, in.Table, false), getSelectTypeString(in.SelectType), encrypt) // 1. 先查缓存 if data, ok := getFromCache(ctx, cacheKey); ok { var records gdb.Result if err := json.Unmarshal(data, &records); err == nil && len(records) > 0 { return records, nil } } // 2. 执行数据库查询 result, err = in.Next(ctx) if err != nil { return nil, err } // 3. 写入缓存 if len(result) > 0 { if data, err := json.Marshal(result); err == nil { if err = setToCache(ctx, cacheKey, data); err != nil { return nil, err } } } return result, nil } func getCacheKey(tenantId uint64, table string, isBlur bool) string { var cacheKey string if g.IsEmpty(tenantId) { cacheKey = fmt.Sprintf("sql:%s", table) } else { cacheKey = fmt.Sprintf("sql:tenantId-%v:%s", tenantId, table) } if isBlur { cacheKey = fmt.Sprintf("%s:*", cacheKey) } return cacheKey } // getSelectTypeString 将 SelectType 枚举转换为可读字符串 func getSelectTypeString(selectType gdb.SelectType) string { switch selectType { case gdb.SelectTypeDefault: return "default" case gdb.SelectTypeCount: return "count" case gdb.SelectTypeValue: return "value" case gdb.SelectTypeArray: return "array" default: return "unknown" } } // ==================== 调用方法 ==================== var TablePrefix string var ( schemaPrefix = "tenant-" cacheKeyPrefix = "cache-" noTenantIdKeyPrefix = "tenantId-" ) type Gfdb interface { Exec(ctx context.Context, sql string, args ...any) (sql.Result, error) GetAll(ctx context.Context, sql string, args ...any) (gdb.Result, error) Model(ctx context.Context, tableNameOrStruct ...any) *model Transaction(ctx context.Context, f func(ctx context.Context, tx gdb.TX) error) error } type cache interface { Cache(ctx context.Context) *gdb.Model } type noTenantId interface { NoTenantId(ctx context.Context) *modelCache } type dataBase struct { gdb.DB } type model struct { *gdb.Model } type modelCache struct { *model } func checkSchemaConfig(ctx context.Context) (uint64, bool) { user, err := utils.GetUserInfo(ctx) if err != nil { return 0, false } var schema = fmt.Sprintf("%s%v", schemaPrefix, user.TenantId) sprintf := fmt.Sprintf("database.%s", schema) if !g.Cfg().MustGet(ctx, sprintf).IsEmpty() { return user.TenantId, true } return user.TenantId, false } func DB(ctx context.Context, name ...string) Gfdb { var groupName = gdb.DefaultGroupName if len(name) > 0 && name[0] != "" { groupName = name[0] } else { tenantId, config := checkSchemaConfig(ctx) if config { groupName = fmt.Sprintf("%s%v", schemaPrefix, tenantId) } } db := g.DB(groupName) TablePrefix = db.GetConfig().Prefix return &dataBase{ DB: db, } } func (d *dataBase) Model(ctx context.Context, tableNameOrStruct ...any) *model { return &model{ Model: d.DB.Model(tableNameOrStruct...).Ctx(ctx).OmitNil().Hook(catchSQLHook()), } } func (d *dataBase) Transaction(ctx context.Context, f func(ctx context.Context, tx gdb.TX) error) error { return d.DB.Transaction(ctx, f) } func (d *model) Cache(ctx context.Context) *gdb.Model { traceID := getTraceID(ctx, cacheKeyPrefix) if traceID == "" { glog.Errorf(ctx, "[DB] GetTraceID error: traceID is empty") return nil } if err := gcache.Set(ctx, traceID, true, time.Second); err != nil { glog.Errorf(ctx, "[DB] Cache error: %v", err) return nil } return d.Model } func (d *model) NoTenantId(ctx context.Context) *modelCache { traceID := getTraceID(ctx, noTenantIdKeyPrefix) if traceID == "" { glog.Errorf(ctx, "[DB] GetTraceID error: traceID is empty") return nil } if err := gcache.Set(ctx, traceID, true, time.Second); err != nil { glog.Errorf(ctx, "[DB] Cache error: %v", err) return nil } return &modelCache{ &model{ Model: d.Model, }, } } // getTraceID 从 context 中获取链路追踪 ID func getTraceID(ctx context.Context, prefix string) string { span := trace.SpanFromContext(ctx) if span != nil && span.SpanContext().HasTraceID() { return fmt.Sprintf("%s%v", prefix, span.SpanContext().TraceID().String()) } return "" }