2026-03-17 16:09:19 +08:00
|
|
|
|
package gfdb
|
|
|
|
|
|
|
|
|
|
|
|
import (
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"regexp"
	"strings"
	"sync"
	"time"

	"gitea.com/red-future/common/beans"
	"gitea.com/red-future/common/utils"
	"github.com/bwmarrin/snowflake"
	"github.com/gogf/gf/v2/crypto/gmd5"
	"github.com/gogf/gf/v2/database/gdb"
	"github.com/gogf/gf/v2/database/gredis"
	"github.com/gogf/gf/v2/frame/g"
	"github.com/gogf/gf/v2/os/gcache"
	"github.com/gogf/gf/v2/os/glog"
	"github.com/gogf/gf/v2/text/gstr"
	"github.com/gogf/gf/v2/util/gconv"
	"go.opentelemetry.io/otel/trace"
)
|
|
|
|
|
|
|
|
|
|
|
|
// ==================== 缓存管理器(单例) ====================
|
|
|
|
|
|
|
|
|
|
|
|
var (
|
|
|
|
|
|
localCache *gcache.Cache
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
// getLocalCache 获取本地缓存实例
|
|
|
|
|
|
func getLocalCache() *gcache.Cache {
|
|
|
|
|
|
if localCache == nil {
|
|
|
|
|
|
localCache = gcache.New()
|
|
|
|
|
|
}
|
|
|
|
|
|
return localCache
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// getFromCache 从缓存获取数据(本地缓存 -> Redis)
|
|
|
|
|
|
func getFromCache(ctx context.Context, key string) ([]byte, bool) {
|
|
|
|
|
|
|
|
|
|
|
|
// 1. 先查本地缓存
|
|
|
|
|
|
if val, err := getLocalCache().Get(ctx, key); err == nil && val != nil {
|
|
|
|
|
|
if data := val.Bytes(); len(data) > 0 {
|
|
|
|
|
|
return data, true
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 2. 再查Redis缓存
|
|
|
|
|
|
if g.Redis() != nil {
|
|
|
|
|
|
result, err := g.Redis().Get(ctx, key)
|
|
|
|
|
|
if err == nil && !result.IsEmpty() {
|
|
|
|
|
|
data := result.Bytes()
|
|
|
|
|
|
// 写入本地缓存
|
|
|
|
|
|
err = getLocalCache().Set(ctx, key, data, time.Duration(g.Cfg().MustGet(ctx, "cache.localTTL").Int64())*time.Second)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, false
|
|
|
|
|
|
}
|
|
|
|
|
|
return data, true
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return nil, false
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// setToCache 写入缓存(本地缓存 + Redis)
|
|
|
|
|
|
func setToCache(ctx context.Context, key string, data []byte) (err error) {
|
|
|
|
|
|
if len(data) == 0 {
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 1. 写入本地缓存
|
|
|
|
|
|
if err = getLocalCache().Set(ctx, key, data, time.Duration(g.Cfg().MustGet(ctx, "cache.localTTL").Int64())*time.Second); err != nil {
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 2. 写入Redis缓存
|
|
|
|
|
|
if g.Redis() != nil {
|
|
|
|
|
|
_, err = g.Redis().Set(ctx, key, data, gredis.SetOption{
|
|
|
|
|
|
TTLOption: gredis.TTLOption{
|
|
|
|
|
|
EX: gconv.PtrInt64(g.Cfg().MustGet(ctx, "cache.redisTTL")),
|
|
|
|
|
|
},
|
|
|
|
|
|
})
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// deleteCacheByPattern 根据模式删除缓存
|
|
|
|
|
|
func deleteCacheByPattern(ctx context.Context, pattern string) (err error) {
|
|
|
|
|
|
// 1. 删除匹配模式的本地缓存
|
|
|
|
|
|
localCache := getLocalCache()
|
|
|
|
|
|
keys := localCache.MustKeyStrings(ctx)
|
|
|
|
|
|
if len(keys) > 0 {
|
|
|
|
|
|
for _, key := range keys {
|
|
|
|
|
|
if matchPattern(key, pattern) {
|
|
|
|
|
|
_, err = localCache.Remove(ctx, key)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 2. 删除Redis缓存(使用SCAN+DEL)
|
|
|
|
|
|
if g.Redis() != nil {
|
|
|
|
|
|
|
|
|
|
|
|
keys, err := g.Redis().Keys(ctx, pattern)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
for _, key := range keys {
|
|
|
|
|
|
_, err = g.Redis().Del(ctx, key)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// matchPattern reports whether key matches a glob-style pattern in which `*`
// stands for any run of characters, including none. Every other character of
// pattern is matched literally; other glob operators such as `?` or `[...]`
// are not interpreted.
func matchPattern(key string, pattern string) bool {
	// Escape every regexp metacharacter, then turn the escaped `\*` back into
	// a wildcard.
	expr := strings.ReplaceAll(regexp.QuoteMeta(pattern), `\*`, ".*")

	// (?s) makes `.` match newlines as well, so `*` really covers any
	// character a key may contain. The anchors force a whole-key match.
	re, err := regexp.Compile("(?s)^" + expr + "$")
	if err != nil {
		// Not expected after QuoteMeta; treat an unusable pattern as no match.
		return false
	}
	return re.MatchString(key)
}
|
|
|
|
|
|
|
|
|
|
|
|
// ==================== 统一Hook入口 ====================
|
|
|
|
|
|
|
|
|
|
|
|
// CatchSQLHook 返回统一的 HookHandler(包含租户自动赋值和缓存)
|
|
|
|
|
|
// 使用示例:
|
|
|
|
|
|
//
|
|
|
|
|
|
// // 基础使用(自动租户赋值,无缓存)
|
|
|
|
|
|
// g.DB().Model("user").Hook(base.CatchSQLHook()).Ctx(ctx).Insert(data)
|
|
|
|
|
|
//
|
|
|
|
|
|
// // 启用缓存(用户无感知,自动处理缓存key)
|
|
|
|
|
|
// ctx = base.WithCacheEnabled(ctx, "asset")
|
|
|
|
|
|
// Asset.CtxWithCache(ctx).Where("id", 123).Scan(&result)
|
|
|
|
|
|
func catchSQLHook() gdb.HookHandler {
|
|
|
|
|
|
return gdb.HookHandler{
|
|
|
|
|
|
Insert: insertHook,
|
|
|
|
|
|
Update: updateHook,
|
|
|
|
|
|
Delete: deleteHook,
|
|
|
|
|
|
Select: selectHook,
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// ==================== Insert钩子 ====================
|
|
|
|
|
|
|
|
|
|
|
|
func insertHook(ctx context.Context, in *gdb.HookInsertInput) (result sql.Result, err error) {
|
|
|
|
|
|
|
2026-03-19 17:07:01 +08:00
|
|
|
|
userInfo, err := utils.GetUserInfo(ctx)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
node, err := snowflake.NewNode(g.Cfg().MustGet(ctx, "server.workerId").Int64())
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
2026-03-17 16:09:19 +08:00
|
|
|
|
for i := range in.Data {
|
2026-03-19 17:07:01 +08:00
|
|
|
|
if _, ok := in.Data[i]["id"]; ok {
|
|
|
|
|
|
in.Data[i]["id"] = node.Generate().Int64()
|
|
|
|
|
|
}
|
2026-03-24 16:17:22 +08:00
|
|
|
|
if _, ok := in.Data[i]["tenant_id"]; ok {
|
|
|
|
|
|
if !g.IsEmpty(userInfo.TenantId) {
|
|
|
|
|
|
in.Data[i]["tenant_id"] = userInfo.TenantId
|
|
|
|
|
|
} else {
|
|
|
|
|
|
return nil, fmt.Errorf("tenantId cannot be empty")
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
if _, ok := in.Data[i]["creator"]; ok {
|
|
|
|
|
|
if !g.IsEmpty(userInfo.UserName) {
|
2026-03-17 16:09:19 +08:00
|
|
|
|
in.Data[i]["creator"] = userInfo.UserName
|
2026-03-24 16:17:22 +08:00
|
|
|
|
} else {
|
|
|
|
|
|
return nil, fmt.Errorf("user info cannot be empty")
|
2026-03-17 16:09:19 +08:00
|
|
|
|
}
|
2026-03-24 16:17:22 +08:00
|
|
|
|
}
|
|
|
|
|
|
if _, ok := in.Data[i]["updater"]; ok {
|
|
|
|
|
|
if !g.IsEmpty(userInfo.UserName) {
|
2026-03-17 16:09:19 +08:00
|
|
|
|
in.Data[i]["updater"] = userInfo.UserName
|
2026-03-24 16:17:22 +08:00
|
|
|
|
} else {
|
|
|
|
|
|
return nil, fmt.Errorf("user info cannot be empty")
|
2026-03-17 16:09:19 +08:00
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 2. 执行插入
|
|
|
|
|
|
result, err = in.Next(ctx)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 3. 清除相关缓存
|
|
|
|
|
|
if userInfo != nil && userInfo.TenantId != 0 {
|
|
|
|
|
|
if err = deleteCacheByPattern(ctx, getCacheKey(userInfo.TenantId, in.Table, true)); err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return result, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// ==================== Update钩子 ====================
|
|
|
|
|
|
|
|
|
|
|
|
func updateHook(ctx context.Context, in *gdb.HookUpdateInput) (result sql.Result, err error) {
|
|
|
|
|
|
// 1. 自动赋值修改人
|
|
|
|
|
|
userInfo, _ := utils.GetUserInfo(ctx)
|
|
|
|
|
|
|
|
|
|
|
|
switch data := in.Data.(type) {
|
|
|
|
|
|
case gdb.Map:
|
2026-03-24 16:17:22 +08:00
|
|
|
|
if _, ok := data["updater"]; ok {
|
|
|
|
|
|
if !g.IsEmpty(userInfo.UserName) {
|
2026-03-17 16:09:19 +08:00
|
|
|
|
data["updater"] = userInfo.UserName
|
2026-03-24 16:17:22 +08:00
|
|
|
|
} else {
|
|
|
|
|
|
return nil, fmt.Errorf("user info cannot be empty")
|
2026-03-17 16:09:19 +08:00
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
case gdb.List:
|
|
|
|
|
|
for i := range data {
|
|
|
|
|
|
if !g.IsEmpty(userInfo.UserName) {
|
2026-03-18 09:18:35 +08:00
|
|
|
|
if _, ok := data[i]["updater"]; ok {
|
2026-03-24 16:17:22 +08:00
|
|
|
|
if !g.IsEmpty(userInfo.UserName) {
|
|
|
|
|
|
data[i]["updater"] = userInfo.UserName
|
|
|
|
|
|
} else {
|
|
|
|
|
|
return nil, fmt.Errorf("user info cannot be empty")
|
|
|
|
|
|
}
|
2026-03-17 16:09:19 +08:00
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 2. 执行更新
|
|
|
|
|
|
result, err = in.Next(ctx)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 3. 清除相关缓存
|
|
|
|
|
|
if userInfo != nil && userInfo.TenantId != 0 {
|
|
|
|
|
|
if err = deleteCacheByPattern(ctx, getCacheKey(userInfo.TenantId, in.Table, true)); err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return result, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// ==================== Delete钩子 ====================
|
|
|
|
|
|
|
|
|
|
|
|
func deleteHook(ctx context.Context, in *gdb.HookDeleteInput) (result sql.Result, err error) {
|
|
|
|
|
|
// 1. 执行删除
|
|
|
|
|
|
result, err = in.Next(ctx)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 2. 清除相关缓存
|
|
|
|
|
|
userInfo, _ := utils.GetUserInfo(ctx)
|
|
|
|
|
|
if userInfo != nil && userInfo.TenantId != 0 {
|
|
|
|
|
|
if err = deleteCacheByPattern(ctx, getCacheKey(userInfo.TenantId, in.Table, true)); err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return result, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// ==================== Select钩子(缓存读取) ====================
|
|
|
|
|
|
|
|
|
|
|
|
func selectHook(ctx context.Context, in *gdb.HookSelectInput) (result gdb.Result, err error) {
|
2026-04-02 10:37:31 +08:00
|
|
|
|
var tenantId uint64
|
|
|
|
|
|
// ===================== 最终版:安全追加租户ID =====================
|
|
|
|
|
|
tenantEnabled, err := gcache.Get(ctx, getTraceID(ctx, noTenantIdKeyPrefix))
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
if !gconv.Bool(tenantEnabled) {
|
|
|
|
|
|
user, err := utils.GetUserInfo(ctx)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
tenantId = user.TenantId
|
|
|
|
|
|
// 【关键修复】找到 SQL 中第一个出现的 ORDER BY / GROUP BY / LIMIT 等关键字位置
|
|
|
|
|
|
sql := in.Sql
|
|
|
|
|
|
insertPos := len(sql)
|
|
|
|
|
|
keywords := []string{" ORDER BY ", " GROUP BY ", " HAVING ", " LIMIT ", " FOR UPDATE"}
|
|
|
|
|
|
for _, kw := range keywords {
|
|
|
|
|
|
if idx := gstr.PosI(sql, kw); idx != -1 {
|
|
|
|
|
|
insertPos = idx
|
|
|
|
|
|
break
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
2026-03-18 09:18:35 +08:00
|
|
|
|
|
2026-04-02 10:37:31 +08:00
|
|
|
|
// 【正确拼接】把条件插入到关键字之前,而不是直接拼在最后
|
|
|
|
|
|
condition := " " + beans.DefSQLBaseCol.TenantId + " = ?"
|
|
|
|
|
|
if gstr.Contains(gstr.ToUpper(sql), " WHERE ") {
|
|
|
|
|
|
// 有 WHERE → 加 AND
|
|
|
|
|
|
in.Sql = sql[:insertPos] + " AND" + condition + sql[insertPos:]
|
|
|
|
|
|
} else {
|
|
|
|
|
|
// 无 WHERE → 加 WHERE
|
|
|
|
|
|
in.Sql = sql[:insertPos] + " WHERE" + condition + sql[insertPos:]
|
|
|
|
|
|
}
|
|
|
|
|
|
in.Args = append(in.Args, tenantId)
|
|
|
|
|
|
}
|
|
|
|
|
|
// ==================================================================
|
|
|
|
|
|
|
|
|
|
|
|
cacheEnabled, err := gcache.Get(ctx, getTraceID(ctx, cacheKeyPrefix))
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
2026-03-17 16:09:19 +08:00
|
|
|
|
// 未启用缓存,直接执行查询
|
2026-04-02 10:37:31 +08:00
|
|
|
|
if !gconv.Bool(cacheEnabled) {
|
2026-03-17 16:09:19 +08:00
|
|
|
|
return in.Next(ctx)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 从 SQL 字符串中提取 WHERE 条件部分
|
|
|
|
|
|
whereCondition := ""
|
|
|
|
|
|
// 查找 WHERE 关键字(不区分大小写)
|
|
|
|
|
|
whereIndex := gstr.PosI(in.Sql, " WHERE ")
|
|
|
|
|
|
if whereIndex != -1 {
|
|
|
|
|
|
// 提取 WHERE 之后的内容
|
|
|
|
|
|
whereCondition = in.Sql[whereIndex+7:]
|
|
|
|
|
|
// 移除 ORDER BY, GROUP BY, HAVING, LIMIT 等后续子句
|
|
|
|
|
|
for _, keyword := range []string{" ORDER BY ", " GROUP BY ", " HAVING ", " LIMIT ", " FOR UPDATE"} {
|
|
|
|
|
|
if idx := gstr.PosI(whereCondition, keyword); idx != -1 {
|
|
|
|
|
|
whereCondition = whereCondition[:idx]
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
encrypt, err := gmd5.Encrypt(fmt.Sprintf("%s:%s", whereCondition, in.Args))
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 构建缓存key:sql:tenantId:table:where条件:args
|
2026-04-02 10:37:31 +08:00
|
|
|
|
cacheKey := fmt.Sprintf("%s:%s:%s", getCacheKey(tenantId, in.Table, false), getSelectTypeString(in.SelectType), encrypt)
|
2026-03-17 16:09:19 +08:00
|
|
|
|
|
|
|
|
|
|
// 1. 先查缓存
|
|
|
|
|
|
if data, ok := getFromCache(ctx, cacheKey); ok {
|
|
|
|
|
|
var records gdb.Result
|
|
|
|
|
|
if err := json.Unmarshal(data, &records); err == nil && len(records) > 0 {
|
|
|
|
|
|
return records, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 2. 执行数据库查询
|
|
|
|
|
|
result, err = in.Next(ctx)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 3. 写入缓存
|
|
|
|
|
|
if len(result) > 0 {
|
|
|
|
|
|
if data, err := json.Marshal(result); err == nil {
|
|
|
|
|
|
if err = setToCache(ctx, cacheKey, data); err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return result, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func getCacheKey(tenantId uint64, table string, isBlur bool) string {
|
2026-04-02 10:37:31 +08:00
|
|
|
|
var cacheKey string
|
|
|
|
|
|
if g.IsEmpty(tenantId) {
|
|
|
|
|
|
cacheKey = fmt.Sprintf("sql:%s", table)
|
|
|
|
|
|
} else {
|
|
|
|
|
|
cacheKey = fmt.Sprintf("sql:tenantId-%v:%s", tenantId, table)
|
|
|
|
|
|
}
|
2026-03-17 16:09:19 +08:00
|
|
|
|
if isBlur {
|
|
|
|
|
|
cacheKey = fmt.Sprintf("%s:*", cacheKey)
|
|
|
|
|
|
}
|
|
|
|
|
|
return cacheKey
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// getSelectTypeString 将 SelectType 枚举转换为可读字符串
|
|
|
|
|
|
func getSelectTypeString(selectType gdb.SelectType) string {
|
|
|
|
|
|
switch selectType {
|
|
|
|
|
|
case gdb.SelectTypeDefault:
|
|
|
|
|
|
return "default"
|
|
|
|
|
|
case gdb.SelectTypeCount:
|
|
|
|
|
|
return "count"
|
|
|
|
|
|
case gdb.SelectTypeValue:
|
|
|
|
|
|
return "value"
|
|
|
|
|
|
case gdb.SelectTypeArray:
|
|
|
|
|
|
return "array"
|
|
|
|
|
|
default:
|
|
|
|
|
|
return "unknown"
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// ==================== 调用方法 ====================
|
|
|
|
|
|
|
|
|
|
|
|
var (
	// schemaPrefix is put in front of the tenant id to form the per-tenant
	// database group / schema name, e.g. "tenant-<id>".
	schemaPrefix = "tenant-"
	// cacheKeyPrefix prefixes the trace-id based gcache flag that
	// (*model).Cache sets and selectHook reads to turn result caching on.
	cacheKeyPrefix = "cache-"
	// noTenantIdKeyPrefix prefixes the trace-id based gcache flag that
	// (*model).NoTenantId sets and selectHook reads to skip the tenant filter.
	noTenantIdKeyPrefix = "tenantId-"
)
|
|
|
|
|
|
|
|
|
|
|
|
// Gfdb is the database handle returned by DB. In this file it is implemented
// by dataBase.
type Gfdb interface {
	// Exec runs a raw SQL statement. dataBase does not define it itself, so
	// it comes from the embedded gdb.DB.
	Exec(ctx context.Context, sql string, args ...any) (sql.Result, error)
	// Model creates a model for the given table or struct with the hooks of
	// this package attached.
	Model(ctx context.Context, tableNameOrStruct ...any) *model
	// Transaction runs f inside a database transaction.
	Transaction(ctx context.Context, f func(ctx context.Context, tx gdb.TX) error) error
}
|
|
|
|
|
|
|
|
|
|
|
|
// cache describes the Cache method offered by model.
// NOTE(review): not referenced anywhere else in this file — confirm whether
// it is still needed.
type cache interface {
	// Cache enables result caching for the query built from the returned model.
	Cache(ctx context.Context) *gdb.Model
}
|
|
|
|
|
|
|
2026-04-02 10:37:31 +08:00
|
|
|
|
// noTenantId describes the NoTenantId method offered by model.
// NOTE(review): not referenced anywhere else in this file — confirm whether
// it is still needed.
type noTenantId interface {
	// NoTenantId disables the automatic tenant filter for the query built
	// from the returned model.
	NoTenantId(ctx context.Context) *modelCache
}
|
|
|
|
|
|
|
|
|
|
|
|
// dataBase wraps a gdb.DB and is the Gfdb value returned by DB. It overrides
// Model and Transaction; every other method comes from the embedded gdb.DB.
type dataBase struct {
	gdb.DB
}
|
|
|
|
|
|
|
2026-04-02 10:37:31 +08:00
|
|
|
|
// model wraps a *gdb.Model and adds the Cache and NoTenantId switches. All
// regular query-building methods come from the embedded *gdb.Model.
type model struct {
	*gdb.Model
}
|
|
|
|
|
|
|
2026-04-02 10:37:31 +08:00
|
|
|
|
// modelCache is the value returned by (*model).NoTenantId. It embeds *model,
// so Cache and the query-building methods remain available on it.
type modelCache struct {
	*model
}
|
|
|
|
|
|
|
|
|
|
|
|
func checkSchemaConfig(ctx context.Context) (uint64, bool) {
|
2026-03-17 16:09:19 +08:00
|
|
|
|
user, err := utils.GetUserInfo(ctx)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
glog.Errorf(ctx, "[DB] GetUserInfo error: %v", err)
|
2026-04-02 10:37:31 +08:00
|
|
|
|
return 0, false
|
2026-03-17 16:09:19 +08:00
|
|
|
|
}
|
|
|
|
|
|
var schema = fmt.Sprintf("%s%v", schemaPrefix, user.TenantId)
|
|
|
|
|
|
sprintf := fmt.Sprintf("database.%s", schema)
|
|
|
|
|
|
if !g.Cfg().MustGet(ctx, sprintf).IsEmpty() {
|
2026-04-02 10:37:31 +08:00
|
|
|
|
return user.TenantId, true
|
2026-04-01 13:38:33 +08:00
|
|
|
|
}
|
2026-04-02 10:37:31 +08:00
|
|
|
|
return user.TenantId, false
|
2026-04-01 13:38:33 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func DB(ctx context.Context) Gfdb {
|
2026-04-02 10:37:31 +08:00
|
|
|
|
tenantId, config := checkSchemaConfig(ctx)
|
|
|
|
|
|
|
2026-04-01 13:38:33 +08:00
|
|
|
|
var schema = fmt.Sprintf("%s%v", schemaPrefix, tenantId)
|
|
|
|
|
|
|
|
|
|
|
|
var dbName []string
|
|
|
|
|
|
if config {
|
2026-03-17 16:09:19 +08:00
|
|
|
|
dbName = append(dbName, schema)
|
|
|
|
|
|
} else {
|
|
|
|
|
|
dbName = append(dbName, "default")
|
2026-03-18 09:18:35 +08:00
|
|
|
|
// 配置文件中 default 是数组格式,需要通过索引 0 访问
|
|
|
|
|
|
defaultConfig := g.Cfg().MustGet(ctx, "database.default")
|
|
|
|
|
|
if defaultConfig.IsSlice() {
|
|
|
|
|
|
schema = g.Cfg().MustGet(ctx, "database.default.0.name").String()
|
|
|
|
|
|
} else {
|
|
|
|
|
|
schema = g.Cfg().MustGet(ctx, "database.default.name").String()
|
|
|
|
|
|
}
|
2026-03-17 16:09:19 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return &dataBase{
|
|
|
|
|
|
DB: g.DB(dbName...).Schema(schema),
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (d *dataBase) Model(ctx context.Context, tableNameOrStruct ...any) *model {
|
2026-03-18 09:18:35 +08:00
|
|
|
|
|
|
|
|
|
|
m := d.DB.Model(tableNameOrStruct...).Ctx(ctx)
|
|
|
|
|
|
|
2026-04-02 10:37:31 +08:00
|
|
|
|
tenantId, config := checkSchemaConfig(ctx)
|
|
|
|
|
|
|
2026-04-01 13:38:33 +08:00
|
|
|
|
if config {
|
2026-03-18 09:18:35 +08:00
|
|
|
|
// 创建按地区分库的配置
|
|
|
|
|
|
shardingConfig := gdb.ShardingConfig{
|
|
|
|
|
|
Schema: gdb.ShardingSchemaConfig{
|
2026-04-01 13:38:33 +08:00
|
|
|
|
Enable: true, // 启用分库
|
|
|
|
|
|
Prefix: schemaPrefix, // 分库前缀
|
|
|
|
|
|
Rule: &RegionShardingRule{RegionMapping: tenantId}, // 自定义分库规则
|
2026-03-18 09:18:35 +08:00
|
|
|
|
},
|
|
|
|
|
|
}
|
2026-04-01 13:38:33 +08:00
|
|
|
|
m.Sharding(shardingConfig).ShardingValue(tenantId)
|
2026-03-17 16:09:19 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2026-03-28 11:10:02 +08:00
|
|
|
|
m.OmitNil().Hook(catchSQLHook())
|
2026-03-17 16:09:19 +08:00
|
|
|
|
return &model{
|
2026-03-18 09:18:35 +08:00
|
|
|
|
Model: m,
|
2026-03-17 16:09:19 +08:00
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (d *dataBase) Transaction(ctx context.Context, f func(ctx context.Context, tx gdb.TX) error) error {
|
|
|
|
|
|
return d.DB.Transaction(ctx, f)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (d *model) Cache(ctx context.Context) *gdb.Model {
|
2026-04-02 10:37:31 +08:00
|
|
|
|
traceID := getTraceID(ctx, cacheKeyPrefix)
|
2026-03-17 16:09:19 +08:00
|
|
|
|
if traceID == "" {
|
|
|
|
|
|
glog.Errorf(ctx, "[DB] GetTraceID error: traceID is empty")
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
if err := gcache.Set(ctx, traceID, true, time.Second); err != nil {
|
|
|
|
|
|
glog.Errorf(ctx, "[DB] Cache error: %v", err)
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
return d.Model
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-04-02 10:37:31 +08:00
|
|
|
|
func (d *model) NoTenantId(ctx context.Context) *modelCache {
|
|
|
|
|
|
traceID := getTraceID(ctx, noTenantIdKeyPrefix)
|
|
|
|
|
|
if traceID == "" {
|
|
|
|
|
|
glog.Errorf(ctx, "[DB] GetTraceID error: traceID is empty")
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
if err := gcache.Set(ctx, traceID, true, time.Second); err != nil {
|
|
|
|
|
|
glog.Errorf(ctx, "[DB] Cache error: %v", err)
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
return &modelCache{
|
|
|
|
|
|
&model{
|
|
|
|
|
|
Model: d.Model,
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-03-17 16:09:19 +08:00
|
|
|
|
// getTraceID 从 context 中获取链路追踪 ID
|
2026-04-02 10:37:31 +08:00
|
|
|
|
func getTraceID(ctx context.Context, prefix string) string {
|
2026-03-17 16:09:19 +08:00
|
|
|
|
span := trace.SpanFromContext(ctx)
|
|
|
|
|
|
if span != nil && span.SpanContext().HasTraceID() {
|
2026-04-02 10:37:31 +08:00
|
|
|
|
return fmt.Sprintf("%s%v", prefix, span.SpanContext().TraceID().String())
|
2026-03-17 16:09:19 +08:00
|
|
|
|
}
|
|
|
|
|
|
return ""
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// RegionShardingRule is the schema sharding rule used by (*dataBase).Model.
type RegionShardingRule struct {
	// RegionMapping is the tenant id this rule was created for; SchemaName
	// compares the sharding value against it.
	RegionMapping uint64
}
|
|
|
|
|
|
|
|
|
|
|
|
func (r *RegionShardingRule) SchemaName(ctx context.Context, config gdb.ShardingSchemaConfig, value any) (string, error) {
|
|
|
|
|
|
region, ok := value.(uint64)
|
|
|
|
|
|
if !ok {
|
|
|
|
|
|
return "", fmt.Errorf("sharding value must be string for RegionShardingRule")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if r.RegionMapping == region {
|
|
|
|
|
|
return config.Prefix + gconv.String(region), nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return "default", nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// TableName 实现分表规则接口
|
|
|
|
|
|
func (r *RegionShardingRule) TableName(ctx context.Context, config gdb.ShardingTableConfig, value any) (string, error) {
|
|
|
|
|
|
// 这里不实现分表,返回空字符串
|
|
|
|
|
|
return "", nil
|
|
|
|
|
|
}
|
2026-04-02 10:37:31 +08:00
|
|
|
|
|
|
|
|
|
|
func GetTablePrefix(ctx context.Context) (prefix string, err error) {
|
|
|
|
|
|
tenantId, config := checkSchemaConfig(ctx)
|
|
|
|
|
|
if config {
|
|
|
|
|
|
sprintf := fmt.Sprintf("database.%s%v.0.prefix", schemaPrefix, tenantId)
|
|
|
|
|
|
prefix = g.Cfg().MustGet(ctx, sprintf).String()
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
defaultConfig := g.Cfg().MustGet(ctx, "database.default")
|
|
|
|
|
|
if defaultConfig.IsSlice() {
|
|
|
|
|
|
prefix = g.Cfg().MustGet(ctx, "database.default.0.prefix").String()
|
|
|
|
|
|
} else {
|
|
|
|
|
|
prefix = g.Cfg().MustGet(ctx, "database.default.prefix").String()
|
|
|
|
|
|
}
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|