diff --git a/beans/beans.go b/beans/beans.go index 31806bd..db34063 100644 --- a/beans/beans.go +++ b/beans/beans.go @@ -3,6 +3,8 @@ package beans import ( "time" + "github.com/gogf/gf/v2/os/gtime" + "go.mongodb.org/mongo-driver/v2/bson" ) @@ -38,17 +40,50 @@ type MongoBaseDO struct { // SQLBaseDO SQL数据库基础实体 type SQLBaseDO struct { - Id uint64 `json:"id"` // 主键ID - Creator string `json:"creator"` // 创建人 - CreatedAt *time.Time `json:"createdAt"` // 创建时间 - Updater string `json:"updater"` // 更新人 - UpdatedAt *time.Time `json:"updatedAt"` // 更新时间 - TenantId string `json:"tenantId"` // 租户ID - IsDeleted bool `json:"isDeleted"` // 是否删除 + Id uint64 `orm:"id" json:"id"` // 主键ID + Bid string `orm:"bid" json:"bid"` // 业务ID + Creator string `orm:"creator" json:"creator"` // 创建人 + CreatedAt *gtime.Time `orm:"created_at" json:"createdAt"` // 创建时间 + Updater string `orm:"updater" json:"updater"` // 更新人 + UpdatedAt *gtime.Time `orm:"updated_at" json:"updatedAt"` // 更新时间 + Deleter string `orm:"deleter" json:"deleter"` // 软删除人 + DeletedAt *gtime.Time `orm:"deleted_at" json:"deletedAt"` // 软删除时间 + IsDeleted bool `orm:"is_deleted" json:"isDeleted"` // 是否删除 +} + +type SQLBaseCol struct { + Id string + Bid string + Creator string + CreatedAt string + Updater string + UpdatedAt string + Deleter string + DeletedAt string + IsDeleted string +} + +var DefSQLBaseCol = SQLBaseCol{ + Id: "id", + Bid: "bid", + Creator: "creator", + CreatedAt: "created_at", + Updater: "updater", + UpdatedAt: "updated_at", + Deleter: "deleter", + DeletedAt: "deleted_at", + IsDeleted: "is_deleted", } type User struct { - UserId interface{} `bson:"userId" json:"userId"` // 用户ID - UserName interface{} `bson:"userName" json:"userName"` // 用户名 - TenantId interface{} `bson:"tenantId" json:"tenantId"` // 租户ID + Id uint64 `orm:"id,primary" json:"id"` // + UserName string `orm:"user_name,unique" json:"userName"` // 用户名 + UserNickname string `orm:"user_nickname" json:"userNickname"` // 用户昵称 + UserPassword string `orm:"user_password" json:"userPassword"` // 登录密码;cmf_password加密 + UserSalt string `orm:"user_salt" json:"userSalt"` // 加密盐 + UserStatus uint `orm:"user_status" json:"userStatus"` // 用户状态;0:禁用,1:正常,2:未验证 + IsAdmin int `orm:"is_admin" json:"isAdmin"` // 是否后台管理员 1 是 0 否 + Avatar string `orm:"avatar" json:"avatar"` //头像 + DeptId uint64 `orm:"dept_id" json:"deptId"` //部门id + TenantId uint64 `orm:"tenant_id" json:"tenantId"` //租户id } diff --git a/db/gfdb/gfdb.go b/db/gfdb/gfdb.go new file mode 100644 index 0000000..152d625 --- /dev/null +++ b/db/gfdb/gfdb.go @@ -0,0 +1,463 @@ +package gfdb + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "regexp" + "strings" + "time" + + "gitea.com/red-future/common/utils" + "github.com/gogf/gf/v2/crypto/gmd5" + "github.com/gogf/gf/v2/database/gdb" + "github.com/gogf/gf/v2/database/gredis" + "github.com/gogf/gf/v2/frame/g" + "github.com/gogf/gf/v2/os/gcache" + "github.com/gogf/gf/v2/os/glog" + "github.com/gogf/gf/v2/text/gstr" + "github.com/gogf/gf/v2/util/gconv" + "go.opentelemetry.io/otel/trace" +) + +// ==================== 缓存管理器(单例) ==================== + +var ( + localCache *gcache.Cache +) + +// getLocalCache 获取本地缓存实例 +func getLocalCache() *gcache.Cache { + if localCache == nil { + localCache = gcache.New() + } + return localCache +} + +// getFromCache 从缓存获取数据(本地缓存 -> Redis) +func getFromCache(ctx context.Context, key string) ([]byte, bool) { + + // 1. 先查本地缓存 + if val, err := getLocalCache().Get(ctx, key); err == nil && val != nil { + if data := val.Bytes(); len(data) > 0 { + return data, true + } + } + + // 2. 再查Redis缓存 + if g.Redis() != nil { + result, err := g.Redis().Get(ctx, key) + if err == nil && !result.IsEmpty() { + data := result.Bytes() + // 写入本地缓存 + err = getLocalCache().Set(ctx, key, data, time.Duration(g.Cfg().MustGet(ctx, "cache.localTTL").Int64())*time.Second) + if err != nil { + return nil, false + } + return data, true + } + } + + return nil, false +} + +// setToCache 写入缓存(本地缓存 + Redis) +func setToCache(ctx context.Context, key string, data []byte) (err error) { + if len(data) == 0 { + return + } + + // 1. 写入本地缓存 + if err = getLocalCache().Set(ctx, key, data, time.Duration(g.Cfg().MustGet(ctx, "cache.localTTL").Int64())*time.Second); err != nil { + return + } + + // 2. 写入Redis缓存 + if g.Redis() != nil { + _, err = g.Redis().Set(ctx, key, data, gredis.SetOption{ + TTLOption: gredis.TTLOption{ + EX: gconv.PtrInt64(g.Cfg().MustGet(ctx, "cache.redisTTL")), + }, + }) + if err != nil { + return + } + } + return +} + +// deleteCacheByPattern 根据模式删除缓存 +func deleteCacheByPattern(ctx context.Context, pattern string) (err error) { + // 1. 删除匹配模式的本地缓存 + localCache := getLocalCache() + keys := localCache.MustKeyStrings(ctx) + if len(keys) > 0 { + for _, key := range keys { + if matchPattern(key, pattern) { + _, err = localCache.Remove(ctx, key) + if err != nil { + return err + } + } + } + } + + // 2. 删除Redis缓存(使用SCAN+DEL) + if g.Redis() != nil { + + keys, err := g.Redis().Keys(ctx, pattern) + if err != nil { + return err + } + for _, key := range keys { + _, err = g.Redis().Del(ctx, key) + if err != nil { + return err + } + } + + } + return nil +} + +// matchPattern 检查 key 是否匹配 Redis SCAN 的 MATCH 模式(支持 * 通配符) +func matchPattern(key string, pattern string) bool { + // 将 Redis 的 MATCH 模式转换为正则表达式 + // 转义正则特殊字符(除了 *) + regexPattern := regexp.QuoteMeta(pattern) + // 将转义后的 \* 替换回 .* + regexPattern = strings.ReplaceAll(regexPattern, `\*`, ".*") + // 添加开始和结束锚点 + regexPattern = "^" + regexPattern + "$" + + matched, _ := regexp.MatchString(regexPattern, key) + return matched +} + +// ==================== 统一Hook入口 ==================== + +// CatchSQLHook 返回统一的 HookHandler(包含租户自动赋值和缓存) +// 使用示例: +// +// // 基础使用(自动租户赋值,无缓存) +// g.DB().Model("user").Hook(base.CatchSQLHook()).Ctx(ctx).Insert(data) +// +// // 启用缓存(用户无感知,自动处理缓存key) +// ctx = base.WithCacheEnabled(ctx, "asset") +// Asset.CtxWithCache(ctx).Where("id", 123).Scan(&result) +func catchSQLHook() gdb.HookHandler { + return gdb.HookHandler{ + Insert: insertHook, + Update: updateHook, + Delete: deleteHook, + Select: selectHook, + } +} + +// ==================== Insert钩子 ==================== + +func insertHook(ctx context.Context, in *gdb.HookInsertInput) (result sql.Result, err error) { + // 1. 自动赋值租户字段 + userInfo, _ := utils.GetUserInfo(ctx) + + //if !g.IsEmpty(userInfo.UserName) { + // in.Model.Data("creator", userInfo.UserName) + // in.Model.Data("updater", userInfo.UserName) + //} + for i := range in.Data { + if !g.IsEmpty(userInfo.UserName) { + if _, ok := in.Data[i]["creator"]; !ok { + in.Data[i]["creator"] = userInfo.UserName + } + if _, ok := in.Data[i]["updater"]; !ok { + in.Data[i]["updater"] = userInfo.UserName + } + } + } + + // 2. 执行插入 + result, err = in.Next(ctx) + if err != nil { + return nil, err + } + + // 3. 清除相关缓存 + if userInfo != nil && userInfo.TenantId != 0 { + if err = deleteCacheByPattern(ctx, getCacheKey(userInfo.TenantId, in.Table, true)); err != nil { + return nil, err + } + } + + return result, nil +} + +// ==================== Update钩子 ==================== + +func updateHook(ctx context.Context, in *gdb.HookUpdateInput) (result sql.Result, err error) { + // 1. 自动赋值修改人 + userInfo, _ := utils.GetUserInfo(ctx) + + //if !g.IsEmpty(userInfo.UserName) { + // in.Model.Data("creator", userInfo.UserName) + // in.Model.Data("updater", userInfo.UserName) + //} + + switch data := in.Data.(type) { + case gdb.Map: + if !g.IsEmpty(userInfo.UserName) { + if _, ok := data["updater"]; !ok { + data["updater"] = userInfo.UserName + } + } + case gdb.List: + for i := range data { + if !g.IsEmpty(userInfo.UserName) { + if _, ok := data[i]["updater"]; !ok { + data[i]["updater"] = userInfo.UserName + } + } + } + } + + // 2. 执行更新 + result, err = in.Next(ctx) + if err != nil { + return nil, err + } + + // 3. 清除相关缓存 + if userInfo != nil && userInfo.TenantId != 0 { + if err = deleteCacheByPattern(ctx, getCacheKey(userInfo.TenantId, in.Table, true)); err != nil { + return nil, err + } + } + + return result, nil +} + +// ==================== Delete钩子 ==================== + +func deleteHook(ctx context.Context, in *gdb.HookDeleteInput) (result sql.Result, err error) { + // 1. 执行删除 + result, err = in.Next(ctx) + if err != nil { + return nil, err + } + + // 2. 清除相关缓存 + userInfo, _ := utils.GetUserInfo(ctx) + if userInfo != nil && userInfo.TenantId != 0 { + if err = deleteCacheByPattern(ctx, getCacheKey(userInfo.TenantId, in.Table, true)); err != nil { + return nil, err + } + } + + return result, nil +} + +// ==================== Select钩子(缓存读取) ==================== + +func selectHook(ctx context.Context, in *gdb.HookSelectInput) (result gdb.Result, err error) { + traceID := getTraceID(ctx) + if traceID == "" { + return nil, fmt.Errorf("[DB] GetTraceID error: traceID is empty") + } + enabled, err := gcache.Get(ctx, traceID) + // 未启用缓存,直接执行查询 + if !gconv.Bool(enabled) { + return in.Next(ctx) + } + + // 从 SQL 字符串中提取 WHERE 条件部分 + whereCondition := "" + // 查找 WHERE 关键字(不区分大小写) + whereIndex := gstr.PosI(in.Sql, " WHERE ") + if whereIndex != -1 { + // 提取 WHERE 之后的内容 + whereCondition = in.Sql[whereIndex+7:] + // 移除 ORDER BY, GROUP BY, HAVING, LIMIT 等后续子句 + for _, keyword := range []string{" ORDER BY ", " GROUP BY ", " HAVING ", " LIMIT ", " FOR UPDATE"} { + if idx := gstr.PosI(whereCondition, keyword); idx != -1 { + whereCondition = whereCondition[:idx] + } + } + } + + user, err := utils.GetUserInfo(ctx) + if err != nil { + return nil, err + } + + encrypt, err := gmd5.Encrypt(fmt.Sprintf("%s:%s", whereCondition, in.Args)) + if err != nil { + return nil, err + } + + // 构建缓存key:sql:tenantId:table:where条件:args + cacheKey := fmt.Sprintf("%s:%s:%s", getCacheKey(user.TenantId, in.Table, false), getSelectTypeString(in.SelectType), encrypt) + + // 1. 先查缓存 + if data, ok := getFromCache(ctx, cacheKey); ok { + var records gdb.Result + if err := json.Unmarshal(data, &records); err == nil && len(records) > 0 { + return records, nil + } + } + + // 2. 执行数据库查询 + result, err = in.Next(ctx) + if err != nil { + return nil, err + } + + // 3. 写入缓存 + if len(result) > 0 { + if data, err := json.Marshal(result); err == nil { + if err = setToCache(ctx, cacheKey, data); err != nil { + return nil, err + } + } + } + + return result, nil +} + +func getCacheKey(tenantId uint64, table string, isBlur bool) string { + cacheKey := fmt.Sprintf("sql:tenantId-%v:%s", tenantId, table) + if isBlur { + cacheKey = fmt.Sprintf("%s:*", cacheKey) + } + return cacheKey +} + +// getSelectTypeString 将 SelectType 枚举转换为可读字符串 +func getSelectTypeString(selectType gdb.SelectType) string { + switch selectType { + case gdb.SelectTypeDefault: + return "default" + case gdb.SelectTypeCount: + return "count" + case gdb.SelectTypeValue: + return "value" + case gdb.SelectTypeArray: + return "array" + default: + return "unknown" + } +} + +// ==================== 调用方法 ==================== + +var ( + schemaPrefix = "tenant-" +) + +type Gfdb interface { + Model(ctx context.Context, tableNameOrStruct ...any) *model + Transaction(ctx context.Context, f func(ctx context.Context, tx gdb.TX) error) error +} + +type cache interface { + Cache(ctx context.Context) *gdb.Model +} + +type model struct { + *gdb.Model +} + +type dataBase struct { + gdb.DB +} + +func DB(ctx context.Context) Gfdb { + var dbName []string + user, err := utils.GetUserInfo(ctx) + if err != nil { + glog.Errorf(ctx, "[DB] GetUserInfo error: %v", err) + return nil + } + + var schema = fmt.Sprintf("%s%v", schemaPrefix, user.TenantId) + sprintf := fmt.Sprintf("database.%s", schema) + if !g.Cfg().MustGet(ctx, sprintf).IsEmpty() { + dbName = append(dbName, schema) + } else { + dbName = append(dbName, "default") + schema = g.Cfg().MustGet(ctx, "database.default.name").String() + } + + return &dataBase{ + DB: g.DB(dbName...).Schema(schema), + } +} + +func (d *dataBase) Model(ctx context.Context, tableNameOrStruct ...any) *model { + user, err := utils.GetUserInfo(ctx) + if err != nil { + glog.Errorf(ctx, "[DB] GetUserInfo error: %v", err) + return nil + } + // 创建按地区分库的配置 + shardingConfig := gdb.ShardingConfig{ + Schema: gdb.ShardingSchemaConfig{ + Enable: true, // 启用分库 + Prefix: schemaPrefix, // 分库前缀 + Rule: &RegionShardingRule{RegionMapping: user.TenantId}, // 自定义分库规则 + }, + } + + hook := d.DB.Model(tableNameOrStruct...).Ctx(ctx).Sharding(shardingConfig).ShardingValue(user.TenantId).OmitNilWhere().Hook(catchSQLHook()) + return &model{ + Model: hook, + } +} + +func (d *dataBase) Transaction(ctx context.Context, f func(ctx context.Context, tx gdb.TX) error) error { + return d.DB.Transaction(ctx, f) +} + +func (d *model) Cache(ctx context.Context) *gdb.Model { + traceID := getTraceID(ctx) + if traceID == "" { + glog.Errorf(ctx, "[DB] GetTraceID error: traceID is empty") + return nil + } + if err := gcache.Set(ctx, traceID, true, time.Second); err != nil { + glog.Errorf(ctx, "[DB] Cache error: %v", err) + return nil + } + return d.Model +} + +// getTraceID 从 context 中获取链路追踪 ID +func getTraceID(ctx context.Context) string { + span := trace.SpanFromContext(ctx) + if span != nil && span.SpanContext().HasTraceID() { + return span.SpanContext().TraceID().String() + } + return "" +} + +type RegionShardingRule struct { + RegionMapping uint64 +} + +func (r *RegionShardingRule) SchemaName(ctx context.Context, config gdb.ShardingSchemaConfig, value any) (string, error) { + region, ok := value.(uint64) + if !ok { + return "", fmt.Errorf("sharding value must be string for RegionShardingRule") + } + + if r.RegionMapping == region { + return config.Prefix + gconv.String(region), nil + } + + return "default", nil +} + +// TableName 实现分表规则接口 +func (r *RegionShardingRule) TableName(ctx context.Context, config gdb.ShardingTableConfig, value any) (string, error) { + // 这里不实现分表,返回空字符串 + return "", nil +} diff --git a/utils/utils.go b/utils/utils.go index 4f741c2..4639379 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "github.com/gogf/gf/v2/os/gtime" "net" "reflect" "sort" @@ -18,6 +17,7 @@ import ( "github.com/gogf/gf/v2/errors/gcode" "github.com/gogf/gf/v2/errors/gerror" "github.com/gogf/gf/v2/frame/g" + "github.com/gogf/gf/v2/os/gtime" "github.com/gogf/gf/v2/util/gconv" "github.com/tiger1103/gfast-token/gftoken" ) @@ -66,67 +66,71 @@ func GetMonthToday(t time.Time, month int) time.Time { return target.AddDate(0, 0, t.Day()-1) } -func GetUserInfo(ctx context.Context) (user beans.User, err error) { - // 检查context是否已取消 - select { - case <-ctx.Done(): - return user, ctx.Err() - default: +func GetUserInfo(ctx context.Context) (user *beans.User, err error) { + // 1. 优先从 context 中获取 + if !g.IsNil(ctx.Value("user")) { + err = gconv.Struct(ctx.Value("user"), &user) + if err != nil { + return user, gerror.Wrap(err, "用户信息转换失败") + } + return } - if !g.IsNil(ctx.Value("id")) || !g.IsNil(ctx.Value("userName")) || !g.IsNil(ctx.Value("tenantId")) { - user.UserId = ctx.Value("id") - user.UserName = ctx.Value("userName") - user.TenantId = ctx.Value("tenantId") - } else { - redisAddr := g.Cfg().MustGet(ctx, "redis.default.address").String() - gft := gftoken.NewGfToken( - gftoken.WithCacheKey("gfToken:"), - gftoken.WithTimeout(20), - gftoken.WithMaxRefresh(10), - gftoken.WithMultiLogin(true), - //gftoken.WithExcludePaths(g.SliceStr{"/excludeDemo"}), - gftoken.WithGRedisConfig(&gredis.Config{ - Address: redisAddr, - Db: 1, - })) - var data *gftoken.CustomClaims - - if !g.IsNil(ctx.Value("token")) { - var tokenData *gftoken.TokenData - tokenData, _, err = gft.GetTokenData(ctx, ctx.Value("token").(string)) + // 2. 从请求头中获取(gateway 转发时设置) + if req := g.RequestFromCtx(ctx); req != nil { + userInfoHeader := req.Header.Get("X-User-Info") + if userInfoHeader != "" { + err = gconv.Struct(userInfoHeader, &user) if err != nil { - return user, gerror.Wrap(err, "token 解析失败") - } - var code int - if data, code = gft.IsNotExpired(tokenData.JwtToken); code != gftoken.JwtTokenOK { - return user, gerror.New("token jwt 解析失败") - } - } else if g.RequestFromCtx(ctx) != nil { - // 解析 token - data, err = gft.ParseToken(g.RequestFromCtx(ctx)) - if err != nil { - return user, gerror.Wrap(err, "token 解析失败") + return user, gerror.Wrap(err, "请求头用户信息解析失败") } + return } - - // 检查 data 是否为 nil - if data == nil { - return user, gerror.New("token 数据为空") - } - // 检查 data.Data 是否为 nil - if data.Data == nil { - g.Log().Errorf(ctx, "data.Data 为空") - return user, gerror.New("用户信息为空") - } - dataMap := gconv.Map(data.Data) - user.UserId = dataMap["id"] - user.UserName = dataMap["userName"] - user.TenantId = dataMap["tenantId"] } - if g.IsNil(user.UserId) && g.IsNil(user.UserName) && g.IsNil(user.TenantId) { - return user, gerror.New("租户信息为空") + // 3. 从 token 解析 + redisAddr := g.Cfg().MustGet(ctx, "redis.default.address").String() + gft := gftoken.NewGfToken( + gftoken.WithCacheKey("gfToken:"), + gftoken.WithTimeout(20), + gftoken.WithMaxRefresh(10), + gftoken.WithMultiLogin(true), + //gftoken.WithExcludePaths(g.SliceStr{"/excludeDemo"}), + gftoken.WithGRedisConfig(&gredis.Config{ + Address: redisAddr, + Db: 1, + })) + var data *gftoken.CustomClaims + + if !g.IsNil(ctx.Value("token")) { + var tokenData *gftoken.TokenData + tokenData, _, err = gft.GetTokenData(ctx, ctx.Value("token").(string)) + if err != nil { + return user, gerror.Wrap(err, "ctx token 解析失败") + } + var code int + if data, code = gft.IsNotExpired(tokenData.JwtToken); code != gftoken.JwtTokenOK { + return user, gerror.New("token jwt 解析失败") + } + } else if g.RequestFromCtx(ctx) != nil { + // 解析 token + data, err = gft.ParseToken(g.RequestFromCtx(ctx)) + if err != nil { + return user, gerror.Wrap(err, "token 解析失败") + } + } + + // 检查 data 是否为 nil + if data == nil { + return user, gerror.New("token 数据为空") + } + // 检查 data.Data 是否为 nil + if data.Data == nil { + return user, gerror.New("用户信息为空") + } + err = gconv.Struct(data.Data, &user) + if err != nil { + return user, gerror.Wrap(err, "用户信息转换失败") } return }